我们对每个令牌的 bert 向量感兴趣。对于 bert 向量,我们指的是 berts 输出层中特定标记的词向量。所以我们想找出哪个令牌产生哪个伯特向量。我们编写了一些代码,但我们不确定它是否正确或如何测试它。
所以在代码中我们用bert处理一个句子。我们构建一个位置 id 列表并将它们交给模型。之后我们使用相同的位置 ID 将标记映射到输出层。然后有一些代码产生计算输入句子中每个向量的字符偏移量。
这是如何使用 position_ids 生成的正确方法吗
from transformers import BertModel, BertConfig, BertTokenizer
import torch
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')
def sentence_to_vector(input_sentence):
tokens_encoded = tokenizer.encode(input_sentence, add_special_tokens=True)
input_ids = torch.tensor(tokens_encoded).unsqueeze(0) # Batch size 1
seq_length = input_ids.size(1)
# code to construct position_ids from here:
# https://github.com/huggingface/transformers/blob/8da280ebbeca5ebd7561fd05af78c65df9161f92/pytorch_pretrained_bert/modeling.py#L188:L189
position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
outputs = model(input_ids, position_ids=position_ids)
tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
# from the BertModel documentation (example at the bottom):
# The last hidden-state is the first element of the output tuple
# https://huggingface.co/transformers/model_doc/bert.html#transformers.BertModel
#ttv = {} # token to vector
#for i in position_ids[0]:
# ttv[tokens[i]] = outputs[0][0][position_ids[0][i]]
data = []
last_offset = 0
for i in range(0, len(position_ids[0])):
token = tokens[position_ids[0][i]]
vector = outputs[0][0][position_ids[0][i]]
pos_begin = None
pos_end = None
if not token == "[CLS]" and not token == "[SEP]":
pos_begin = input_sentence.find(token, last_offset)
pos_end = pos_begin + len(token)
last_offset = pos_end
data.append({
"token": token,
"pos_begin": pos_begin,
"pos_end": pos_end,
"vector": vector
})
return data
input_sentence = "do the chicken dance!"
data = sentence_to_vector(input_sentence)
for token in data:
print(token["token"] + "\t" + str(token["pos_begin"]) + "\t" + str(token["pos_end"]) + "\t" + str(token["vector"][0:3]) + "..." )