我最近开始使用 torchtext 来替换我的胶水代码,我遇到了一个问题,我想在我的架构中使用注意力层。为了做到这一点,我需要知道我的训练数据的最大序列长度。
问题在于,torchtext.data.BucketIterator
它会按批次进行填充:
# All 4 examples in the batch will be padded to maxlen in the batch
train_iter = torchtext.data.BucketIterator(dataset=train, batch_size=4)
是否有某种方法可以确保所有训练示例都填充到相同的长度;即训练中的maxlen?