我正在使用 HuggingFace 预训练模型通过和facebook/bart-large-cnn
进行文本摘要。模型和标记器都加载正常:AutoModel
AutoTokenizer
import os
import torch
from transformers import AutoTokenizer, AutoModel
torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn",
cache_dir=os.getenv("cache_dir", "model"))
model = AutoModel.from_pretrained("facebook/bart-large-cnn",
cache_dir=os.getenv("cache_dir", "model")).to(torch_device)
FRANCE_ARTICLE = ' Marseille...' # @noqa
dct = tokenizer.batch_encode_plus(
[FRANCE_ARTICLE],
max_length=1024,
padding="max_length",
truncation=True,
return_tensors="pt",
)
max_length = 140
min_length = 55
hypotheses_batch = model.generate(
input_ids=dct["input_ids"].to(torch_device),
attention_mask=dct["attention_mask"].to(torch_device),
num_beams=4,
length_penalty=2.0,
max_length=max_length + 2,
min_length=min_length + 1,
no_repeat_ngram_size=3,
do_sample=False,
early_stopping=True,
decoder_start_token_id=model.config.eos_token_id,
)
decoded = [
tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in hypotheses_batch
]
print(decoded)
但是当我在标记器上调用解码时出现此错误tokenizer.batch_encode_plus
:
Traceback (most recent call last):
File "src/summarization/run.py", line 42, in <module>
summary_ids = model.generate(article_input_ids,num_beams=4,length_penalty=2.0,max_length=142,min_length=56,no_repeat_ngram_size=3)
File "/usr/local/lib/python3.7/site-packages/torch/autograd/grad_mode.py", line 15, in decorate_context
return func(*args, **kwargs)
File "/usr/local/lib/python3.7/site-packages/transformers/generation_utils.py", line 379, in generate
assert hasattr(self, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self)
AssertionError: BartModel(
(shared): Embedding(50264, 1024, padding_idx=1)
(encoder): BartEncoder(
(embed_tokens): Embedding(50264, 1024, padding_idx=1)
(embed_positions): LearnedPositionalEmbedding(1026, 1024, padding_idx=1)
(layers): ModuleList(
(0): EncoderLayer(
...
)
)
(layernorm_embedding): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
)
) should have a 'get_encoder' function defined