根据语言模型获得令牌的概率相对容易,如下面的片段所示。您可以获取模型的输出,将自己限制在掩码标记的输出中,然后在输出向量中找到您请求的标记的概率。然而,这只适用于单标记词,例如,它们本身就在标记器的词汇表中。当词汇表中不存在一个词时,分词器会将其分块成它确实知道的部分(参见示例的底部)。但是由于输入句子只包含一个被掩蔽的位置,而请求的 token 有更多的 token,我们如何才能得到它的概率呢?最终,我正在寻找一种不管单词有多少子词单元都可以工作的解决方案。
在下面的代码中,我添加了许多注释来解释发生了什么,以及打印出打印语句的给定输出。您会看到预测诸如“爱”和“恨”之类的标记很简单,因为它们在标记器的词汇表中。但是,'reprimand' 不是,所以它不能在单个掩码位置被预测——它由三个子词单元组成。那么我们如何预测蒙面位置的“谴责”呢?
from transformers import BertTokenizer, BertForMaskedLM
import torch
# init model and tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForMaskedLM.from_pretrained('bert-base-uncased')
model.eval()
# init softmax to get probabilities later on
sm = torch.nn.Softmax(dim=0)
torch.set_grad_enabled(False)
# set sentence with MASK token, convert to token_ids
sentence = f"I {tokenizer.mask_token} you"
token_ids = tokenizer.encode(sentence, return_tensors='pt')
print(token_ids)
# tensor([[ 101, 1045, 103, 2017, 102]])
# get the position of the masked token
masked_position = (token_ids.squeeze() == tokenizer.mask_token_id).nonzero().item()
# forward
output = model(token_ids)
last_hidden_state = output[0].squeeze(0)
# only get output for masked token
# output is the size of the vocabulary
mask_hidden_state = last_hidden_state[masked_position]
# convert to probabilities (softmax)
# giving a probability for each item in the vocabulary
probs = sm(mask_hidden_state)
# get probability of token 'hate'
hate_id = tokenizer.convert_tokens_to_ids('hate')
print('hate probability', probs[hate_id].item())
# hate probability 0.008057191967964172
# get probability of token 'love'
love_id = tokenizer.convert_tokens_to_ids('love')
print('love probability', probs[love_id].item())
# love probability 0.6704086065292358
# get probability of token 'reprimand' (?)
reprimand_id = tokenizer.convert_tokens_to_ids('reprimand')
# reprimand is not in the vocabulary, so it needs to be split into subword units
print(tokenizer.convert_ids_to_tokens(reprimand_id))
# [UNK]
reprimand_id = tokenizer.encode('reprimand', add_special_tokens=False)
print(tokenizer.convert_ids_to_tokens(reprimand_id))
# ['rep', '##rim', '##and']
# but how do we now get the probability of a multi-token word in a single-token position?