I am using PyTorch (1.10) with the Multi30k German-to-English dataset for machine translation. I use spaCy to tokenize the English and German sentences, and I want to pass the tokenized data to torchtext.legacy.data.BucketIterator so it can pad the sequences and convert the tokens to indices. I am getting an error related to sort_key that I don't understand. Can someone please help me?
Code
import spacy
from torchtext.datasets import Multi30k  # English-German dataset for machine translation
from torchtext.legacy.data import Field, BucketIterator
spacy_eng = spacy.load("en_core_web_sm")
spacy_ger = spacy.load("de_core_news_sm")
def tokenize_eng(text):
    return [tok.text for tok in spacy_eng.tokenizer(text)]

def tokenize_ger(text):
    return [tok.text for tok in spacy_ger.tokenizer(text)]
english = Field(sequential=True, use_vocab=True, tokenize=tokenize_eng, lower=True, init_token='<sos>', eos_token='<eos>')
german = Field(sequential=True, use_vocab=True, tokenize=tokenize_ger, lower=True, init_token='<sos>', eos_token='<eos>')
train, valid, test = Multi30k(root=".data", split=('train', 'valid', 'test'), language_pair=('en', 'de'))
# build the vocabularies from the training data
english.build_vocab(train, max_size=10000, min_freq=2)
german.build_vocab(train, max_size=10000, min_freq=2)
train_data, valid_data, test_data = BucketIterator.splits((train, valid, test),
                                                          batch_size=64,
                                                          device='cuda')
Error
Traceback (most recent call last):
File "D:\Torch\Multi30K_inbuilt_dataset.py", line 28, in <module>
train_data, valid_data, test_data = BucketIterator.splits((train, valid, test),
File "C:\Users\Devanshu\anaconda3\envs\deeplearning\lib\site-packages\torchtext\legacy\data\iterator.py", line 99, in splits
ret.append(cls(
File "C:\Users\Devanshu\anaconda3\envs\deeplearning\lib\site-packages\torchtext\legacy\data\iterator.py", line 59, in __init__
self.sort_key = dataset.sort_key
File "C:\Users\Devanshu\anaconda3\envs\deeplearning\lib\site-packages\torch\utils\data\dataset.py", line 226, in __getattr__
raise AttributeError
AttributeError
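For comparison, here is a minimal sketch of the legacy-style loading I believe BucketIterator is meant to work with (I am assuming torchtext.legacy.datasets.Multi30k is available in this torchtext version, that exts=('.de', '.en') are the right file extensions, and that its download mirrors still work):

from torchtext.legacy.datasets import Multi30k as LegacyMulti30k

# reuses the `german` and `english` Fields defined above;
# exts picks the (source, target) file extensions, fields the matching Fields
train, valid, test = LegacyMulti30k.splits(exts=('.de', '.en'), fields=(german, english))

german.build_vocab(train, max_size=10000, min_freq=2)
english.build_vocab(train, max_size=10000, min_freq=2)

# the legacy translation dataset defines sort_key, which BucketIterator reads
train_data, valid_data, test_data = BucketIterator.splits((train, valid, test),
                                                          batch_size=64,
                                                          device='cuda')

Is the problem that the new-style torchtext.datasets.Multi30k objects do not carry a sort_key, or am I using BucketIterator incorrectly?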