我正在尝试使用 torchtext 中的 BucketIterator.splits 函数从 csv 文件中加载数据以在 CNN 中使用。除非我有一批最长的句子比最大的过滤器尺寸短,否则一切正常。
在我的示例中,我有大小为 3、4 和 5 的过滤器,因此如果最长的句子没有至少 5 个单词,我会收到错误消息。有没有办法让 BucketIterator 为批次动态设置填充,同时设置最小填充长度?
这是我用于 BucketIterator 的代码:
train_iter, val_iter, test_iter = BucketIterator.splits((train, val, test), sort_key=lambda x: len(x.text), batch_size=batch_size, repeat=False, device=device)
我希望有一种方法可以在 sort_key 或类似的东西上设置最小长度?
我试过这个但它不起作用:
FILTER_SIZES = [3,4,5]
train_iter, val_iter, test_iter = BucketIterator.splits((train, val, test), sort_key=lambda x: len(x.text) if len(x.text) >= FILTER_SIZES[-1] else FILTER_SIZES[-1], batch_size=batch_size, repeat=False, device=device)