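My "label" field is a one-hot vector of length 201. However, I am unable to create an iterator with this one-hot representation. If I try to iterate over the iterator, I get the error shown after the code.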
from torchtext.data import Field
from torchtext.data import TabularDataset
from torchtext.data import Iterator, BucketIterator
tokenize = lambda x: x.split()
TEXT = Field(sequential=True, tokenize=tokenize, lower=True)
LABEL = Field(sequential=True, use_vocab=False)
datafields = [("text", TEXT), ("label", LABEL)]
train, test = TabularDataset.splits(
    path='/home/karthik/Documents/Deep_Learning/73Strings/',
    train="train.csv", validation="test.csv",
    format='csv',
    skip_header=True,
    fields=datafields)
train_iter, val_iter = BucketIterator.splits(
    (train, test),  # we pass in the datasets we want the iterator to draw data from
    batch_sizes=(64, 64),
    device=device,  # if you want to use the GPU, specify the GPU number here
    sort_key=lambda x: len(x.text),  # the BucketIterator needs to be told what function it should use to group the data.
    sort_within_batch=False,
    repeat=False  # we pass repeat=False because we want to wrap this Iterator layer.
)
test_iter = Iterator(test, batch_size=64, sort=False, sort_within_batch=False, repeat=False)
for batch in train_iter:
    print(batch)
ValueError                                Traceback (most recent call last)
in <module>()
----> 1 for batch in train_iter:
      2     print(batch)

/usr/local/lib/python3.6/dist-packages/torchtext/data/iterator.py in __iter__(self)
    155                     else:
    156                         minibatch.sort(key=self.sort_key, reverse=True)
--> 157                 yield Batch(minibatch, self.dataset, self.device)
    158             if not self.repeat:
    159                 return

/usr/local/lib/python3.6/dist-packages/torchtext/data/batch.py in __init__(self, data, dataset, device)
     32                 if field is not None:
     33                     batch = [getattr(x, name) for x in data]
---> 34                     setattr(self, name, field.process(batch, device=device))
     35
     36     @classmethod

/usr/local/lib/python3.6/dist-packages/torchtext/data/field.py in process(self, batch, device)
    199         """
    200         padded = self.pad(batch)
--> 201         tensor = self.numericalize(padded, device=device)
    202         return tensor
    203

/usr/local/lib/python3.6/dist-packages/torchtext/data/field.py in numericalize(self, arr, device)
    321             arr = self.postprocessing(arr, None)
    322
--> 323         var = torch.tensor(arr, dtype=self.dtype, device=device)
    324
    325         if self.sequential and not self.batch_first:

ValueError: too many dimensions 'str'
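
One possible reading of the traceback, offered as a sketch rather than a confirmed diagnosis: the final frame calls torch.tensor on the label batch, and the message mentions 'str', which suggests the label tokens are still strings at that point. The post does not show how the labels are stored in the CSV, so the example assumes each label cell is a whitespace-separated string of 201 zeros and ones. The helper name parse_one_hot and its expected_len parameter are hypothetical and are not part of torchtext.

```python
from typing import List


def parse_one_hot(cell: str, expected_len: int = 201) -> List[int]:
    """Convert a whitespace-separated one-hot string into a list of ints.

    Assumption: the CSV stores each label as e.g. "0 0 1 0 ...".
    The real file layout is not shown in the post.
    """
    values = [int(tok) for tok in cell.split()]
    if len(values) != expected_len:
        raise ValueError(f"expected {expected_len} values, got {len(values)}")
    return values


# Build a toy label cell: 201 entries with a single 1 at index 3.
cell = " ".join("1" if i == 3 else "0" for i in range(201))
label = parse_one_hot(cell)

# The result holds ints, which can be turned into a numeric tensor,
# unlike the raw string tokens.
print(len(label), label.index(1))  # prints: 201 3
```

If the Field tokenizes the cell itself, the same string-to-int conversion could be applied to the resulting token list through the Field's preprocessing argument. Whether that removes the error depends on the actual CSV layout.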