
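I'm using PyTorch (v1.10) and the Multi30k German-to-English dataset for machine translation. I tokenize the English and German sentences with spaCy and want to pass the tokenized data to `torchtext.legacy.data.BucketIterator`, which should pad the sequences and convert the strings to indices. Instead I get an error related to `sort_key` that I don't understand. Can someone please help?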

Code

import spacy
from torchtext.datasets import Multi30k # English-German dataset for machine translation
from torchtext.legacy.data import Field, BucketIterator

spacy_eng = spacy.load("en_core_web_sm")
spacy_ger = spacy.load("de_core_news_sm")

def tokenize_eng(text):
    return [tok.text for tok in spacy_eng.tokenizer(text)]

def tokenize_ger(text):
    return [tok.text for tok in spacy_ger.tokenizer(text)]

english = Field(sequential=True, use_vocab=True, tokenize=tokenize_eng, lower=True, init_token='<sos>', eos_token='<eos>')
german = Field(sequential=True, use_vocab=True, tokenize=tokenize_ger, lower=True, init_token='<sos>', eos_token='<eos>')

train, valid, test = Multi30k(root=".data", split=('train', 'valid', 'test'), language_pair=('en', 'de'))

# will make vocabulary from train data
english.build_vocab(train, max_size=10000, min_freq=2)
german.build_vocab(train, max_size=10000, min_freq=2)


train_data, valid_data, test_data = BucketIterator.splits((train, valid, test),
                                                          batch_size=64,
                                                          device='cuda')

Error

Traceback (most recent call last):
  File "D:\Torch\Multi30K_inbuilt_dataset.py", line 28, in <module>
    train_data, valid_data, test_data = BucketIterator.splits((train, valid, test),
  File "C:\Users\Devanshu\anaconda3\envs\deeplearning\lib\site-packages\torchtext\legacy\data\iterator.py", line 99, in splits
    ret.append(cls(
  File "C:\Users\Devanshu\anaconda3\envs\deeplearning\lib\site-packages\torchtext\legacy\data\iterator.py", line 59, in __init__
    self.sort_key = dataset.sort_key
  File "C:\Users\Devanshu\anaconda3\envs\deeplearning\lib\site-packages\torch\utils\data\dataset.py", line 226, in __getattr__
    raise AttributeError
AttributeError

1 Answer


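Hopefully this isn't too late, but try the legacy version of the dataset. The `Multi30k` in `torchtext.datasets` is the new-style dataset: it yields raw sentence pairs and has no `sort_key` attribute, which the legacy `BucketIterator` reads from the dataset (`self.sort_key = dataset.sort_key` in the traceback). `torchtext.legacy.datasets.Multi30k` builds a legacy dataset from your `Field`s, so it works with `build_vocab` and `BucketIterator`: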

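from torchtext.legacy.datasets import Multi30k  # legacy dataset, compatible with legacy Field/BucketIterator
train, valid, test = Multi30k.splits(exts=('.de', '.en'), fields=(german, english))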
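Why `BucketIterator` needs `sort_key`: it sorts examples by that key so sentences of similar length land in the same batch, which keeps padding small. The snippet below is a simplified pure-Python sketch of that idea (bucket by a sort key, pad each batch to its longest sentence, map tokens to indices). Every name in it is hypothetical and the toy sentences are made up. It is not torchtext's implementation, and the real iterator does more (for example, shuffling).

```python
from typing import Callable, Dict, List, Sequence, Tuple

# Hypothetical stand-in for a legacy torchtext Example: (source tokens, target tokens).
Example = Tuple[List[str], List[str]]

# Special tokens get the first indices, like <pad>/<unk>/<sos>/<eos> in a Field vocab.
SPECIALS = ["<pad>", "<unk>", "<sos>", "<eos>"]


def sort_key(ex: Example) -> Tuple[int, int]:
    # Assumption: order by (source length, target length).
    return (len(ex[0]), len(ex[1]))


def build_vocab(sentences: Sequence[List[str]]) -> Dict[str, int]:
    # Specials first, then every token in order of first appearance.
    stoi: Dict[str, int] = {}
    for tok in SPECIALS + [t for s in sentences for t in s]:
        if tok not in stoi:
            stoi[tok] = len(stoi)
    return stoi


def numericalize(tokens: List[str], stoi: Dict[str, int], max_len: int) -> List[int]:
    # Wrap in <sos>/<eos>, pad to max_len, map tokens to indices (unseen -> <unk>).
    seq = ["<sos>"] + tokens + ["<eos>"]
    seq += ["<pad>"] * (max_len - len(seq))
    return [stoi.get(t, stoi["<unk>"]) for t in seq]


def bucket_batches(examples: List[Example], batch_size: int,
                   key: Callable[[Example], Tuple[int, int]]) -> List[List[Example]]:
    # Sorting by the key puts sentences of similar length in the same batch.
    ordered = sorted(examples, key=key)
    return [ordered[i:i + batch_size] for i in range(0, len(ordered), batch_size)]


def to_index_batch(batch: List[Example], stoi: Dict[str, int], side: int) -> List[List[int]]:
    # Pad each sentence to the longest one in this batch (+2 for <sos>/<eos>).
    max_len = max(len(ex[side]) for ex in batch) + 2
    return [numericalize(ex[side], stoi, max_len) for ex in batch]


# Toy German-English pairs (src = German, trg = English), already tokenized.
data: List[Example] = [
    (["ein", "hund", "rennt", "schnell"], ["a", "dog", "runs", "fast"]),
    (["hallo"], ["hello"]),
    (["zwei", "frauen"], ["two", "women"]),
    (["ein", "kind", "spielt"], ["a", "child", "plays"]),
]

src_stoi = build_vocab([src for src, _ in data])
batches = bucket_batches(data, batch_size=2, key=sort_key)
for batch in batches:
    print(to_index_batch(batch, src_stoi, side=0))
```

With the legacy dataset you don't write any of this yourself: the legacy `Multi30k` dataset defines `sort_key`, and `BucketIterator` does the padding and numericalization through the `Field`s.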
answered 2021-12-18 at 21:24