我为 PyTorch CNN 项目编写了自定义数据集和数据加载器。这是数据集的相关代码
class MyDataset(Dataset):
def __init__(self):
pass
def __len__(self):
return COUNT
def __getitem__(self, idx):
x, y = X[idx], Y[idx]
x = image_augment(x) # custom func to resize image to 32x32
return x, y
每次训练的形状x
都是[4, 32, 32, 3]
.
这是我的网络代码,直接取自这个 PyTorch 示例。
class Net(nn.Module):
def __init__(self, nc):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, nc)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
当我尝试用我的 DataLoader 中的数据训练这个网络时,我得到了错误语句Given groups=1, weight of size [6, 3, 5, 5], expected input[4, 32, 32, 3] to have 3 channels, but got 200 channels instead
。在我看来,我的问题是来自我的 DataLoader 使用的数据的形状x.view(4, 3, 32, 32)
,但后来我收到一个错误,说我couldn't use Conv2D on a ByteTensor
。我在这里有点迷路,非常感谢任何帮助。谢谢!