1

我想在情感分析任务上训练一个 RNN,对于这个任务,我使用了由 torchtext 提供的 IMDB 数据集,其中包含 50000 条电影评论,它是一个 python 迭代器。我用了一个split=('train', 'test').

我首先使用每个句子构建一个词汇torchtext.vocab.Vocab并对其进行标记,然后进行数字化。

为了将序列填充到我使用的相同长度,torch.nn.utils.rnn.pad_sequence并且还使用了 acollate_fnbatch_sampler. 然后我使用 torch.utils.data 加载数据。DataLoader.

RNN 网络的实现很好,但数据加载器在一个 epoch 后就耗尽了,如下图所示。

我是否遵循正确的方法来加载这个可迭代数据集?以及为什么数据加载器在一个时期后耗尽,我该如何克服这个问题。

如果您想查看我的实现,请参阅共享的 colab 笔记本。

PS。我正在关注来自github的torchtext的官方变更日志

你可以在这里找到我的实现

数据加载器在单个 epoch 后耗尽

4

1 回答 1

0

解决方案是使用torchtext.data.functional.to_map_style_dataset(iter_data)官方文档)将您的可迭代样式数据集转换为地图样式数据集。

像这样:

from torchtext.data.functional import to_map_style_dataset
train_iter = IMDB(split='train')
train_dataset = to_map_style_dataset(train_iter)  #Map-style dataset

然后制作一个数据加载器。

from torch.utils.data import DataLoader
train_dataloader = DataLoader(train_dataset, batch_size=64, collate_fn=collate_fn)

为什么会这样?

我使用上面示例的命名约定来解释。

train_iter传递给的Dataloader是一个 Iterable 样式的数据集,这意味着它没有__getitem__实现。它只有__iter____next__dunders - 这使它成为可迭代的。

因此,如果我将一个可迭代对象传递给Dataloader,则数据加载器会在异常发生后停止-当数据集(可迭代对象)耗尽时,可迭代样式数据集(在这种情况下)的 dunderStopIteration将抛出异常。__next__train_iter

因此,我们使用该to_map_style_dataset函数将 Iterable-style 转换为 map-style 数据集。它通过实现一个__getitem__dunder 来实现,因此Dataloader默认使用索引从数据集中获取项目。

做同样事情的另一种可能的方式也可以是

如果我要使用可迭代式数据集 - 我需要Dataloader在每个时期创建对象。因此,在每个 epoch 之后,新的数据加载器对象将在 for 循环中从头开始运行。

为了更好地理解 Pytorch 中 Iterable 样式和 Map 样式数据集的区别和用例,请参阅此https://yizhepku.github.io/2020/12/26/dataloader.html

于 2021-07-01T10:51:07.987 回答