5

当我尝试通过打印BucketIterator对象的下一次迭代来查看批次时,AttributeError会抛出 。

tv_datafields=[("Tweet",TEXT), ("Anger",LABEL), ("Fear",LABEL), ("Joy",LABEL), ("Sadness",LABEL)]
train, vld = data.TabularDataset.splits(path="./data/", train="train.csv",validation="test.csv",format="csv", fields=tv_datafields)

train_iter, val_iter = BucketIterator.splits(
(train, vld),
batch_sizes=(64, 64),
device=-1,
sort_key=lambda x: len(x.Tweet),
sort_within_batch=False,
repeat=False
)
print(next(iter(train_dl)))
4

1 回答 1

1

我不确定您遇到的具体错误,但在这种情况下,您可以使用以下代码迭代批处理:

for i in train_iter:
    print i.Tweet
    print i.Anger
    print i.Fear
    print i.Joy
    print i.Sadness

i.Tweet(还有其他)是 shape 的张量(input_data_length, batch_size)

因此,要查看单个批次数据(比如说批次 0),您可以执行print i.Tweet[:,0].

val_iter(和test_iter,如果需要)也是如此。

于 2018-10-09T02:33:57.900 回答