解决方案是使用torchtext.data.functional.to_map_style_dataset(iter_data)
(官方文档)将您的可迭代样式数据集转换为地图样式数据集。
像这样:
from torchtext.data.functional import to_map_style_dataset
train_iter = IMDB(split='train')
train_dataset = to_map_style_dataset(train_iter) #Map-style dataset
然后制作一个数据加载器。
from torch.utils.data import DataLoader
train_dataloader = DataLoader(train_dataset, batch_size=64, collate_fn=collate_fn)
为什么会这样?
我使用上面示例的命名约定来解释。
train_iter
传递给的Dataloader
是一个 Iterable 样式的数据集,这意味着它没有__getitem__
实现。它只有__iter__
和__next__
dunders - 这使它成为可迭代的。
因此,如果我将一个可迭代对象传递给Dataloader
,则数据加载器会在异常发生后停止-当数据集(可迭代对象)耗尽时,可迭代样式数据集(在这种情况下)的 dunderStopIteration
将抛出异常。__next__
train_iter
因此,我们使用该to_map_style_dataset
函数将 Iterable-style 转换为 map-style 数据集。它通过实现一个__getitem__
dunder 来实现,因此Dataloader
默认使用索引从数据集中获取项目。
做同样事情的另一种可能的方式也可以是
如果我要使用可迭代式数据集 - 我需要Dataloader
在每个时期创建对象。因此,在每个 epoch 之后,新的数据加载器对象将在 for 循环中从头开始运行。
为了更好地理解 Pytorch 中 Iterable 样式和 Map 样式数据集的区别和用例,请参阅此https://yizhepku.github.io/2020/12/26/dataloader.html