0

BERT我正在编写一个使用线性层和softmax顶部层进行预训练的问答系统。当遵循网上可用的模板时,一个示例的标签通常只包含一个answer_start_index和一个answer_end_index。例如,从Huggingface实例化SQUADFeatures对象时开始:

```
self.input_ids = input_ids
self.attention_mask = attention_mask
self.token_type_ids = token_type_ids
self.cls_index = cls_index
self.p_mask = p_mask

self.example_index = example_index
self.unique_id = unique_id
self.paragraph_len = paragraph_len
self.token_is_max_context = token_is_max_context
self.tokens = tokens
self.token_to_orig_map = token_to_orig_map

self.start_position = start_position
self.end_position = end_position
self.is_impossible = is_impossible
self.qas_id = qas_id
```

但是,在我自己的数据集中,我有一些示例,其中在上下文中的多个位置找到了答案词,即可能有几个正确的跨度构成了答案。

我的问题是我不知道如何管理这样的例子?在网络上可用的模板中,标签通常位于列表中,例如:

  • [start_example1,start_example2,start_example3]
  • [end_example1,end_example2,end_example3]

就我而言,这可能看起来像:

  • [start_example1, [start_example2_1, start_example2_2], start_example3]
  • 当然也一样

换句话说,我没有每个示例包含一个标签的列表,而是包含单个标签或示例“标签”列表的列表,即由列表组成的列表。

遵循其他模板时,该过程的下一步是:

```
input_ids = torch.cat(input_ids, dim=0)
attention_masks = torch.cat(attention_masks, dim=0)
token_type_ids = torch.cat(token_type_ids, dim=0)
span_starts = torch(span_starts) #Something like this
span_ends = torch(span_ends) #Something like this
```

但是,这当然(?)会引发错误,因为我的 span_start 列表和 span_end 列表不仅包含单个项目,而且有时包含列表中的列表。

任何人都知道如何解决这个问题?我应该只使用只有一个跨度构成上下文中答案的示例吗?

如果我解决了火炬错误,那么损失的反向传播/评估/计算仍然有效吗?

谢谢你!/乙

4

1 回答 1

0

你检查过代码吗

from transformers import BertTokenizer, BertForQuestionAnswering
import torch

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForQuestionAnswering.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')

question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
encoding = tokenizer.encode_plus(question, text)
input_ids, token_type_ids = encoding["input_ids"], encoding["token_type_ids"]
start_scores, end_scores = model(torch.tensor([input_ids]), token_type_ids=torch.tensor([token_type_ids]))

all_tokens = tokenizer.convert_ids_to_tokens(input_ids)
answer = ' '.join(all_tokens[torch.argmax(start_scores) : torch.argmax(end_scores)+1])

assert answer == "a nice puppet"

我不确定这是否是最好的方法,但您可以检查而不是 argmax 来使用topk,并检查这是否对应于正确的答案。

t = torch.LongTensor([0,1,2,3,4,5,6,7,8,9])
t
_, indices = t.topk(4)
indices#([9, 8, 7, 6])
于 2020-05-27T22:34:23.007 回答