1
from transformers import BertTokenizer, BertForMaskedLM
import torch

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForMaskedLM.from_pretrained('bert-base-uncased')

input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1
outputs = model(input_ids, masked_lm_labels=input_ids)

loss, prediction_scores = outputs[:2] 

此代码来自拥抱脸变压器页面。https://huggingface.co/transformers/model_doc/bert.html#bertformaskedlm

我无法理解 中的masked_lm_labels=input_ids论点model。它是如何工作的?是不是表示通过时会自动屏蔽部分文字input_ids

4

1 回答 1

0

第一个参数是掩码输入,masked_lm_labels参数是所需的输出。

input_ids应该被屏蔽。通常,如何进行掩蔽取决于您。在最初的 BERT 中,他们选择了 15% 的代币和以下,或者

  • 使用[MASK]代币;或者
  • 使用随机令牌;或者
  • 保持原始令牌不变。

这会修改输入,因此您需要告诉模型原始的非屏蔽输入是什么,即masked_lm_labels参数。另请注意,您不想仅计算实际选择用于掩码的标记的损失。其余的标记应替换为 index -100

有关更多详细信息,请参阅文档

于 2020-04-28T07:36:01.117 回答