2

我正在尝试在我的自定义数据集上微调“RobertaForQuestionAnswering”,但我对它需要的输入参数感到困惑。这是示例代码。

>>> from transformers import RobertaTokenizer, RobertaForQuestionAnswering
>>> import torch

>>> tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
>>> model = RobertaForQuestionAnswering.from_pretrained('roberta-base')

>>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
>>> inputs = tokenizer(question, text, return_tensors='pt')
>>> start_positions = torch.tensor([1])
>>> end_positions = torch.tensor([3])

>>> outputs = model(**inputs, start_positions=start_positions, end_positions=end_positions)
>>> loss = outputs.loss
>>> start_scores = outputs.start_logits
>>> end_scores = outputs.end_logits

我无法理解模型中作为输入给出的变量start_positionsend_positions以及正在生成的变量start_scoresend_scores 。

4

1 回答 1

1

问答 ot 基本上是一个 DL 模型,它通过提取部分上下文(在您的情况下称为text)来创建答案。这意味着 QAbot 的目标是识别答案的开始结束


QAbot 的基本功能:

首先,问题和上下文的每个单词都被标记化。这意味着它(可能分为字符/子词,然后)转换为数字。它实际上取决于标记器的类型(这意味着它取决于您使用的模型,因为您将使用相同的标记器 - 这是您代码的第三行正在执行的操作)。我建议这个非常有用的指南

然后,将标记化question + text的内容传递到执行其内部操作的模型中。还记得我一开始说过模型会识别答案的startend吗?好吧,它是通过计算每个标记的question + text特定标记是答案开始的概率来实现的。这个概率是start_logits. 之后,对结束令牌执行相同的操作。

start_scores所以,这就是end_scores:pre-softmax 分数,每个标记分别是答案的开始和结束。


那么,什么是start_positionstop_position

如此处所述,它们是:

start_positions( torch. LongTensorof shape ( batch_size,), optional) - 用于计算标记分类损失的标记跨度开始的位置(索引)标签。位置被限制到序列的长度 ( sequence_length)。计算损失时不考虑序列之外的位置。

end_positions( torch. LongTensorof shape ( batch_size,), optional) – 用于计算标记分类损失的标记跨度末端的位置(索引)标签。位置被限制到序列的长度 ( sequence_length)。计算损失时不考虑序列之外的位置。


此外,您使用的模型(roberta-base参见HuggingFace 存储库RoBERTa 官方论文中的模型)尚未针对 QuestionAnswering 进行微调。它“只是”一个使用 MaskedLanguageModeling 训练的模型,这意味着该模型对英语有一个大致的了解,但不适合提问。您当然可以使用它,但它可能会给出非最佳结果。

我建议您使用相同的模型,在 QuestionAnswering: 上专门微调的版本中roberta-base-squad2,请参阅HuggingFace

实际上,您必须将加载模型和标记器的行替换为:

tokenizer = RobertaTokenizer.from_pretrained('roberta-base-squad2')
model = RobertaForQuestionAnswering.from_pretrained('roberta-base-squad2')

这将给出更准确的结果。

额外阅读:什么是微调以及它是如何工作的

于 2021-10-12T20:52:15.083 回答