(更新了对问题的进一步见解)
我有一个包含 3000 张图像的数据集,它们通过以下几行进入 DataLoader:
training_left_eyes = torch.utils.data.DataLoader(train_dataset, batch_size=2,shuffle=True, drop_last=True)
print(len(training_left_eyes)) #Outputs 1500
我的训练循环如下所示:
for i,(data,output) in enumerate(training_left_eyes):
data,output = data.to(device),output.to(device)
prediction = net(data)
loss = costFunc(prediction,output)
closs = loss.item()
optimizer.zero_grad()
loss.backward()
optimizer.step()
print("batch #{}".format(i))
if i%100 == 0:
print('[%d %d] loss: %.8f' % (epoch+1,i+1,closs/1000))
closs = 0
张量“数据”和“输出”(标签)中的信息是正确的,系统工作正常,直到达到批号 1500。我所有的批次都满了 3000/2=1500,没有剩余。一旦到达最后一批,就会出现一个 RunTimeError 说明存在 0 维输入大小。但我不知道为什么会发生这种情况,因为 enumerate(training_left_eyes) 应该遍历已满的 DataLoader 的值。
我在网上搜索了如何解决这个问题,有些人提到 DataLoader 上的“drop_last=True”属性,虽然这样做是为了防止半空批次进入模型,但我还是尝试了但无济于事。
我开始过于狭隘了,似乎无法靠自己解决问题。我可以简单地插入一个 if 语句,但我认为这是非常糟糕的做法,我想学习正确的解决方案。
如果有帮助,这是我的自定义数据集:
class LeftEyeDataset(torch.utils.data.Dataset):
"""Left eye retinography dataset. Normal/Not-normal"""
def __init__(self, csv_file, root_dir, transform=None):
"""
Args:
csv_file (string): Path to the csv file with annotations.
root_dir (string): Directory with all the images.
transform (callable, optional): Optional transform to be applied
on a sample.
"""
self.labels = label_mapping(csv_file)
self.root_dir = root_dir
self.transform = transform
self.names = name_mapping(csv_file)
def __len__(self):
return len(self.labels)
def __getitem__(self, idx):
if torch.is_tensor(idx):
idx = idx.tolist()
img_name = self.root_dir +'/'+ self.names[idx]
image = io.imread(img_name)
label = self.labels[idx]
if self.transform:
image = self.transform(image)
return image,label
def label_mapping(csv_file) -> np.array:
df = read_excel(excel_file, 'Sheet1')
x= []
for key,value in df['Left-Diagnostic Keywords'].iteritems():
if value=='normal fundus':
x.append(1)
else:
x.append(0)
x_tensor = torch.LongTensor(x)
return x_tensor
def name_mapping(csv_file) -> list:
#Reads the names of the excel file
df = read_excel(excel_file, 'Sheet1')
names= list()
serie = df['Left-Fundus']
for i in range(serie.size):
names.append(df['Left-Fundus'][i])
return names
如果需要,我可以提供任何额外的代码。
更新:经过一段时间尝试解决问题后,我设法查明正在发生的事情。由于某种原因,在最后一批中,进入网络的数据很好,但在第一层发生之前,它就消失了。在下一张图片中,您可以看到我在进入forward(self,x)之前和之后所做的打印。尺寸对齐,直到批号 61(我在本示例中将其从 1500 减少),其中它以某种方式通过打印两次。在该行之后,出现上述错误。