I am trying out Hugging Face's MarianMT NLP model, and I want to use my own decoding implementation (greedy decoding, for example). I compared my implementation against the Hugging Face one (model.generate()).
I have checked my code many times, but I cannot figure out why the two produce different generated text. Can you help me?
My implementation:
# Hugging Face setup
from transformers import MarianMTModel, MarianTokenizer
import torch
import torch.nn.functional as F
import numpy as np

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = MarianTokenizer.from_pretrained('Helsinki-NLP/opus-mt-en-de')
model = MarianMTModel.from_pretrained('Helsinki-NLP/opus-mt-en-de').to(device)
vocab_size = tokenizer.vocab_size
max_t = 15
done = [False]  # Indicates whether generation is finished (EOS token reached)

# My implementation of greedy search
source = ['While that might be a somewhat morbid thought, I think it has some really profound implications that are worth exploring.']
# (prepare_seq2seq_batch is deprecated on newer transformers versions;
#  tokenizer(source, return_tensors='pt') is the replacement)
encoded = tokenizer.prepare_seq2seq_batch(source, return_tensors='pt').to(device)
# 58100 is the pad token id, which Marian also uses as the decoder start token
generated_tokens = torch.tensor([[58100]]).to(device)

for t in range(1, max_t - 1):
    # Logits at the last decoder position
    model_output = model(input_ids=encoded['input_ids'],
                         attention_mask=encoded['attention_mask'],
                         decoder_input_ids=generated_tokens)['logits'].detach()[:, -1, :].reshape(1, vocab_size)
    distrib = F.softmax(model_output, dim=-1).reshape(1, vocab_size)
    # Sort the distribution; greedy search only needs the top index
    distrib = torch.sort(distrib, dim=-1, descending=True)
    next_token = torch.tensor([[distrib.indices[0][0].item()]]).to(device)
    # Mask out the next token for sentences that are already done
    next_token = torch.tensor(1 - np.array(done) * 1).view(-1, 1).to(device) * next_token
    generated_tokens = torch.cat([generated_tokens, next_token], dim=-1)
    # A sentence is done once it emits EOS or we hit the last step
    # (the loop's final iteration is t == max_t - 2, so test against that)
    if (not done[0] and next_token[0][0].item() == model.config.eos_token_id) or t == max_t - 2:
        done[0] = True
    # If all sentences are generated, exit the loop
    if all(done):
        break

gen_sentences = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
print(gen_sentences)
Output: ['Obwohl das ein etwas morbider Gedanke sein könnte, denke']
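To see where the two generations first diverge, it may help to also print the raw token ids and decode without skip_special_tokens (a small debugging addition of mine, reusing the variables above):
# Debugging aid: check whether the raw argmax ever picks a special token
# such as <pad> (id 58100) mid-sequence
print(generated_tokens)
print(tokenizer.batch_decode(generated_tokens))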
# Hugging Face Implementation
translated_greedy = model.generate(input_ids=encoded['input_ids'].to(device), max_length=max_t + 2, do_sample=False)
translated_greedy_sen = tokenizer.batch_decode(translated_greedy, skip_special_tokens=True)
print(translated_greedy_sen)
Output: ['Das mag zwar ein etwas morbider Gedanke sein, aber ich denke']
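One difference I did notice, though I have not verified that it explains the gap: as far as I can tell from the transformers source, generate() never lets a Marian model emit the pad token (the model config lists the pad token under bad_words_ids, and older versions masked the pad logit in adjust_logits_during_generation), while my raw argmax is free to pick it. A minimal sketch of one greedy step with that masking added (my assumption, not a confirmed cause):
# Sketch: one greedy step with the pad logit masked before the argmax,
# mimicking what generate() appears to do for Marian models (assumption)
logits = model(input_ids=encoded['input_ids'],
               attention_mask=encoded['attention_mask'],
               decoder_input_ids=generated_tokens)['logits'].detach()[:, -1, :]
logits[:, model.config.pad_token_id] = float('-inf')  # never select <pad>
next_token = logits.argmax(dim=-1, keepdim=True)
Is this masking the missing piece, or does generate() differ in some other way (for example, max_length counting the decoder start token)?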
Thanks in advance.