
We are interested in the BERT vector for each token. By "BERT vector" we mean the word vector for a specific token in BERT's output layer. So we want to find out which token produces which BERT vector. We have written some code, but we are not sure whether it is correct or how to test it.
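
To make "output layer" concrete: the model's last hidden state contains one row per input token, so the vector for token i is last_hidden_state[0, i]. A minimal shape check of our own (a sketch, assuming bert-base-uncased with hidden size 768):

from transformers import BertModel, BertTokenizer
import torch

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')

ids = torch.tensor(tokenizer.encode("do the chicken dance!", add_special_tokens=True)).unsqueeze(0)
with torch.no_grad():
    out = model(ids)
print(out[0].shape)  # (1, number_of_tokens, 768) -- one 768-dim vector per token, incl. [CLS] and [SEP]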

In the code we run a sentence through BERT. We build a list of position ids and hand them to the model. Afterwards we use the same position ids to map each token to the output layer. Then there is some code that computes the character offset of each vector within the input sentence.

Is this the correct way to use position_ids to produce this mapping?

from transformers import BertModel, BertConfig, BertTokenizer
import torch

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')

def sentence_to_vector(input_sentence):
    tokens_encoded = tokenizer.encode(input_sentence, add_special_tokens=True)
    input_ids = torch.tensor(tokens_encoded).unsqueeze(0)  # Batch size 1

    seq_length = input_ids.size(1)

    # position_ids built the same way the model does internally, copied from:
    # https://github.com/huggingface/transformers/blob/8da280ebbeca5ebd7561fd05af78c65df9161f92/pytorch_pretrained_bert/modeling.py#L188:L189
    # (BertModel constructs exactly this arange itself when position_ids is omitted)
    position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
    position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
    with torch.no_grad():  # inference only, no gradient tracking needed
        outputs = model(input_ids, position_ids=position_ids)

    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
    # from the BertModel documentation (example at the bottom):
    # The last hidden-state is the first element of the output tuple
    # https://huggingface.co/transformers/model_doc/bert.html#transformers.BertModel

    #ttv = {}  # token to vector
    #for i in position_ids[0]:
    #    ttv[tokens[i]] = outputs[0][0][position_ids[0][i]]

    data = []
    last_offset = 0
    for i in range(len(position_ids[0])):
        # position_ids[0][i] is simply i here, so token i is paired with output row i
        token = tokens[position_ids[0][i]]
        vector = outputs[0][0][position_ids[0][i]]
        pos_begin = None
        pos_end = None
        if token not in ("[CLS]", "[SEP]"):
            # NOTE: find() assumes the token text appears verbatim in the input;
            # this breaks for '##' subword pieces and for uppercase input, since
            # bert-base-uncased lowercases during tokenization
            pos_begin = input_sentence.find(token, last_offset)
            pos_end = pos_begin + len(token)
            last_offset = pos_end
        data.append({
            "token": token,
            "pos_begin": pos_begin,
            "pos_end": pos_end,
            "vector": vector
        })
    return data

input_sentence = "do the chicken dance!"
data = sentence_to_vector(input_sentence)

for token in data:
    print(token["token"] + "\t" + str(token["pos_begin"]) + "\t" + str(token["pos_end"]) + "\t" + str(token["vector"][0:3]) + "..." )
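
One way to test the offsets (our own suggestion, assuming a transformers version recent enough to ship BertTokenizerFast and return_offsets_mapping) is to compare against the offset mapping the fast tokenizer computes itself:

from transformers import BertTokenizerFast

fast_tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
encoded = fast_tokenizer(input_sentence, return_offsets_mapping=True, add_special_tokens=True)

for entry, (start, end) in zip(data, encoded["offset_mapping"]):
    # the fast tokenizer reports (0, 0) for special tokens like [CLS] and [SEP]
    expected = (None, None) if (start, end) == (0, 0) else (start, end)
    assert (entry["pos_begin"], entry["pos_end"]) == expected, entry["token"]
print("character offsets agree with the fast tokenizer's offset_mapping")

If the assertion passes for sentences without subword splits, the position_ids round trip and the offsets are consistent; sentences that produce '##' pieces will expose the find() limitation noted in the code above.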
