0

我想在情感分析任务上训练一个 RNN,对于这个任务,我使用了由 torchtext 提供的 IMDB 数据集,其中包含 50000 条电影评论,它是一个 python 迭代器。我用了一个split=('train', 'test').

我首先使用每个句子构建一个词汇torchtext.vocab.Vocab并对其进行标记,然后进行数字化。

为了将序列填充到我使用的相同长度,torch.nn.utils.rnn.pad_sequence并且还使用了 acollate_fnbatch_sampler. 然后我使用 torch.utils.data 加载数据。DataLoader.

RNN 网络的实现很好,但数据加载器在一个 epoch 后就耗尽了,如下图所示。

我是否遵循正确的方法来加载这个可迭代数据集?以及为什么数据加载器在一个时期后耗尽,我该如何克服这个问题。

如果您想查看我的实现,请参阅共享的 colab 笔记本。

PS。我正在关注来自github的torchtext的官方变更日志

你可以在这里找到我的实现

数据加载器在单个 epoch 后耗尽

在所附图像中,您可以看到数据加载器在单个 epoch 后耗尽。

4

1 回答 1

0

问题是您的数据加载器是一个生成器,并且在完全迭代后耗尽。一种解决方案是在每个时期初始化数据加载器。二是不要使用批量采样器。整理功能应该做你想做的事。

def collate_batch(batch):
    batch.sort(key=lambda x: len(x[0]), reverse=True)
    label_list, text_list, text_lengths = [], [], []
    
    for _text, _label in batch:
        label_list.append(_label)
        processed_text = torch.tensor(_text)
        text_list.append(processed_text)
        text_lengths.append(len(processed_text))

    return torch.tensor(label_list, dtype=torch.float32),
           pad_sequence(text_list, padding_value=3.0), 
           torch.tensor(text_lengths, dtype=torch.int64)
于 2021-06-21T20:00:48.393 回答