
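I followed this tutorial http://www.programmersought.com/article/2609385756/

to create a TabularDataset from data that is already tokenized and converted to ids. I don't want to use or build a vocab, because the data is already numeric.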

So I defined my Field as:

from torchtext import data
from torchtext.data import Field, BucketIterator
myField = Field(tokenize=x_tokenize, use_vocab=False, sequential=True)  # ids are already numeric, so no vocab
train, val, test = data.TabularDataset.splits(path='./', train=train_path, validation=valid_path, test=test_path, format='csv', fields=data_fields, skip_header=True)
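x_tokenize is not shown here. As a minimal sketch, assuming (hypothetically) that each CSV cell stores the ids separated by spaces, the tokenizer has to return ints itself, because torchtext does not convert the tokens of a sequential field to numbers when use_vocab=False:

```python
def x_tokenize(text):
    # Hypothetical tokenizer: assumes a cell looks like "101 3177 3702 11293 1116 102".
    # Returns ints, since a sequential Field with use_vocab=False leaves tokens as they are.
    return [int(tok) for tok in text.split()]


print(x_tokenize("101 3177 3702 11293 1116 102"))
# [101, 3177, 3702, 11293, 1116, 102]
```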

Printing the first training example:

print(vars(train[0])['src'])
# output: [101, 3177, 3702, 11293, 1116, 102]

I used a BucketIterator:

train_iter = BucketIterator(train,
                            batch_size=BATCH_SIZE,
                            device=DEVICE,
                            sort_key=lambda x: (len(x.src), len(x.trg)),
                            train=True,
                            batch_size_fn=batch_size_fn,
                            repeat=False)

When I run this code:

batch = next(iter(train_iter))

I get TypeError: an integer is required (got type list)


TypeError                                 Traceback (most recent call last)
in <module>()
----> 1 batch = next(iter(train_iter))

3 frames

/usr/local/lib/python3.6/dist-packages/torchtext/data/iterator.py in __iter__(self)
    155                     else:
    156                         minibatch.sort(key=self.sort_key, reverse=True)
--> 157                 yield Batch(minibatch, self.dataset, self.device)
    158             if not self.repeat:
    159                 return

/usr/local/lib/python3.6/dist-packages/torchtext/data/batch.py in __init__(self, data, dataset, device)
     32                 if field is not None:
     33                     batch = [getattr(x, name) for x in data]
---> 34                     setattr(self, name, field.process(batch, device=device))
     35 
     36     @classmethod

/usr/local/lib/python3.6/dist-packages/torchtext/data/field.py in process(self, batch, device)
    199         """
    200         padded = self.pad(batch)
--> 201         tensor = self.numericalize(padded, device=device)
    202         return tensor
    203 

/usr/local/lib/python3.6/dist-packages/torchtext/data/field.py in numericalize(self, arr, device)
    321             arr = self.postprocessing(arr, None)
    322 
--> 323         var = torch.tensor(arr, dtype=self.dtype, device=device)
    324 
    325         if self.sequential and not self.batch_first:

TypeError: an integer is required (got type list)


1 Answer


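You have to provide a pad_token when declaring the Field. The default pad_token is the string "<pad>", and with use_vocab=False torchtext does not convert the tokens of a sequential field to numbers, so the padded batch ends up mixing integer ids with "<pad>" strings, and torch.tensor cannot build an integer tensor from it.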

Change this:

myField = Field(tokenize= x_tokenize, use_vocab=False, sequential=True)

to this:

myField = Field(tokenize= x_tokenize, use_vocab=False, sequential=True, pad_token=0)
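The padding step is where this goes wrong. Here is a simplified pure-Python sketch of it (pad_batch is a hypothetical helper, not torchtext's actual implementation; it assumes right padding and no init_token/eos_token, which are Field's defaults), comparing the default pad_token "<pad>" with pad_token=0:

```python
from typing import List, Sequence, Union

Token = Union[int, str]


def pad_batch(batch: Sequence[List[int]], pad_token: Token = "<pad>") -> List[List[Token]]:
    # Right-pad every example to the length of the longest one in the batch,
    # roughly what Field.pad does for a sequential field.
    max_len = max(len(x) for x in batch)
    return [list(x) + [pad_token] * (max_len - len(x)) for x in batch]


batch = [[101, 3177, 102], [101, 3177, 3702, 11293, 1116, 102]]

# Default pad token: strings end up next to the integer ids.
print(pad_batch(batch))
# [[101, 3177, 102, '<pad>', '<pad>', '<pad>'], [101, 3177, 3702, 11293, 1116, 102]]

# pad_token=0: every entry is an int, so an integer tensor can be built from it.
print(pad_batch(batch, pad_token=0))
# [[101, 3177, 102, 0, 0, 0], [101, 3177, 3702, 11293, 1116, 102]]
```

Any id works as long as no real token uses it; if the ids come from a BERT tokenizer, as 101 and 102 suggest, 0 is BERT's [PAD] id.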

answered 2020-06-26T17:11:26.940