我在collate_fn
为 PyTorchDataLoader
类编写自定义函数时遇到问题。我需要自定义函数,因为我的输入具有不同的维度。
我目前正在尝试编写斯坦福 MURA 论文的基线实现。数据集有一组带标签的研究。一项研究可能包含多个图像。我创建了一个自定义Dataset
类,使用torch.stack
.
然后将堆叠张量作为输入提供给模型,并对输出列表进行平均以获得单个输出。此实现适用于DataLoader
when batch_size=1
。但是,当我尝试将 设置batch_size
为 8 时,就像原始论文中的情况一样,DataLoader
失败了,因为它用于torch.stack
堆叠批次并且我的批次中的输入具有可变尺寸(因为每个研究都可以有多个图像)。
为了解决这个问题,我尝试实现我的自定义collate_fn
函数。
def collate_fn(batch):
imgs = [item['images'] for item in batch]
targets = [item['label'] for item in batch]
targets = torch.LongTensor(targets)
return imgs, targets
然后在我的训练周期循环中,我像这样遍历每个批次:
for image, label in zip(*batch):
label = label.type(torch.FloatTensor)
# wrap them in Variable
image = Variable(image).cuda()
label = Variable(label).cuda()
# forward
output = model(image)
output = torch.mean(output)
loss = criterion(output, label, phase)
但是,这并没有给我任何改进的时代时间,并且仍然需要与只有 1 的批量大小一样长。我还尝试将批量大小设置为 32,但这也没有改善时间。
难道我做错了什么?有更好的方法吗?