我开始使用 pytorch-lightning 并遇到了自定义数据加载器的问题:
我使用自己的数据集和通用的 torch.utils.data.DataLoader。基本上,数据集采用路径并加载与数据加载器加载的给定索引相对应的数据。
def train_dataloader(self):
train_set = TextKeypointsDataset(parameters...)
train_loader = torch.utils.data.DataLoader(train_set, batch_size, num_workers)
return train_loader
当我使用 pytorch-lightning 模块时train_dataloader
,training_step
一切运行良好。当我添加val_dataloader
并validation_step
遇到此错误时:
Epoch 1: 45%|████▌ | 10/22 [00:02<00:03, 3.34it/s, loss=5.010, v_num=131199]
ValueError: Expected input batch_size (1500) to match target batch_size (5)
在这种情况下,我的数据集非常小(用于测试功能),只有 84 个样本,我的批量大小为 8。用于训练和验证的数据集具有相同的长度(仅用于测试目的)。
所以总共有 84 * 2 = 168 和 168 / 8 (batchsize) = 21,大致就是上面显示的总步数 (22)。这意味着在训练数据集上运行 10 次 (10 * 8 = 80) 后,加载器期望新的完整样本为 8,但由于只有 84 个样本,我得到一个错误(至少这是我目前的理解)。
我在自己的实现中遇到了类似的问题(不使用 pytorch-lighntning)并使用这种模式来解决它。基本上,当数据用完时,我正在重置迭代器:
try:
data = next(data_iterator)
source_tensor = data[0]
target_tensor = data[1]
except StopIteration: # reinitialize data loader if num_iteration > amount of data
data_iterator = iter(data_loader)
现在看来我面临类似的事情?当我的 training_dataloader 数据不足时,我不知道如何在 pytorch-lightning 中重置/重新初始化数据加载器。我想一定有另一种我不熟悉的复杂方式。谢谢