I have modified a PyTorch implementation of a BERT-NER model by adding a CRF layer on top. The class below works correctly in my case.
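As posted, the class relies on a few names it never defines: the imports, the START_TAG/STOP_TAG sentinels, and the batch_log_sum_exp helper. Here is a minimal sketch of what they might look like, assuming the pytorch-pretrained-bert API the class is written against (it uses init_bert_weights and output_all_encoded_layers) and a log-sum-exp that reduces over the last dimension; the sentinel strings are my assumption, not taken from the original.

import torch
import torch.nn as nn
from pytorch_pretrained_bert.modeling import BertModel, BertPreTrainedModel

# Assumed sentinel tags; the `labels` list passed to the model must contain
# both, since __init__ looks them up in tag_to_ix.
START_TAG = "<START>"
STOP_TAG = "<STOP>"

def batch_log_sum_exp(x):
    # Numerically stable log-sum-exp over the last dimension:
    # reduces [B, C, C] -> [B, C] inside the recursion, [B, C] -> [B] at the end.
    m = x.max(-1)[0]
    return m + torch.log(torch.sum(torch.exp(x - m.unsqueeze(-1)), -1))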
class BertWithCRF(BertPreTrainedModel):
    def __init__(self, config, labels, dropout=0.1):
        super(BertWithCRF, self).__init__(config)
        self.tagset_size = len(labels)
        self.tag_to_ix = {label: i for i, label in enumerate(labels)}
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(config.hidden_size, self.tagset_size)
        self.apply(self.init_bert_weights)
        # transitions[i, j] = score of transitioning to tag i from tag j
        self.transitions = nn.Parameter(
            torch.zeros(self.tagset_size, self.tagset_size))
        # never transition to START, never transition from STOP
        self.transitions.data[self.tag_to_ix[START_TAG], :] = -10000
        self.transitions.data[:, self.tag_to_ix[STOP_TAG]] = -10000
    def _batch_forward_alg(self, feats, mask):
        # computes log Z(x), the log partition function, in the log domain
        # feats is [batch_size, seq_len, tagset_size] emission scores
        assert mask is not None
        # initialize alpha with -10000 everywhere except the START tag
        score = feats.new_full((feats.size(0), self.tagset_size), -10000.)
        score[:, self.tag_to_ix[START_TAG]] = 0.
        mask = mask.float()
        trans = self.transitions.unsqueeze(0)  # [1, C, C]
        for t in range(feats.size(1)):  # recursion through the sequence
            mask_t = mask[:, t].unsqueeze(1)
            emit_t = feats[:, t].unsqueeze(2)  # [B, C, 1]
            score_t = score.unsqueeze(1) + emit_t + trans  # [B, 1, C] -> [B, C, C]
            score_t = batch_log_sum_exp(score_t)  # [B, C, C] -> [B, C]
            # carry the previous score forward on padded positions
            score = score_t * mask_t + score * (1 - mask_t)
        score = batch_log_sum_exp(score + self.transitions[self.tag_to_ix[STOP_TAG]])
        return score  # log partition function, [B]
    def _batch_score_sentence(self, feats, tags, mask):
        # computes score(x, y), the score of the gold tag sequence
        assert mask is not None
        score = feats.new_zeros(feats.size(0))
        feats = feats.unsqueeze(3)
        mask = mask.float()
        trans = self.transitions.unsqueeze(2)
        # prepend START so tags[t] is the previous tag and tags[t + 1] the current one
        add_start_tags = torch.empty(tags.size(0), 1).fill_(self.tag_to_ix[START_TAG]).type_as(tags)
        tags = torch.cat([add_start_tags, tags], dim=-1)
        for t in range(feats.size(1)):  # recursion through the sequence
            mask_t = mask[:, t]
            emit_t = torch.cat([h[t, y[t + 1]] for h, y in zip(feats, tags)])
            trans_t = torch.cat([trans[y[t + 1], y[t]] for y in tags])
            score += (emit_t + trans_t) * mask_t
        # last non-pad tag (index mask.sum() in the START-prepended sequence)
        last_tag = tags.gather(1, mask.sum(1).long().unsqueeze(1)).squeeze(1)
        score += self.transitions[self.tag_to_ix[STOP_TAG], last_tag]
        return score
    def _batch_viterbi_decode(self, feats, mask):
        assert mask is not None
        # initialize backpointers and viterbi variables in log space
        bptr = []
        score = feats.new_full((feats.size(0), self.tagset_size), -10000.)
        score[:, self.tag_to_ix[START_TAG]] = 0.
        mask = mask.float()
        for t in range(feats.size(1)):  # recursion through the sequence
            mask_t = mask[:, t].unsqueeze(1)
            score_t = score.unsqueeze(1) + self.transitions  # [B, 1, C] -> [B, C, C]
            score_t, bptr_t = score_t.max(2)  # best previous scores and tags
            score_t += feats[:, t]  # plus emission scores
            bptr.append(bptr_t)
            score = score_t * mask_t + score * (1 - mask_t)
        bptr = torch.stack(bptr, 1)  # [B, T, C]
        score += self.transitions[self.tag_to_ix[STOP_TAG]]
        best_score, best_tag = torch.max(score, 1)
        # back-tracking
        bptr = bptr.tolist()
        best_path = [[i] for i in best_tag.tolist()]
        for b in range(feats.size(0)):
            x = best_tag[b].item()  # best tag
            y = int(mask[b].sum().item())  # no. of non-pad tokens
            # follow backpointers through the non-pad steps only
            for bptr_t in reversed(bptr[b][:y]):
                x = bptr_t[x]
                best_path[b].append(x)
            best_path[b].pop()  # drop the START tag
            best_path[b].reverse()
        # pad variable-length paths so they stack into a [B, T] tensor
        # (0 is an arbitrary pad index; padded positions are masked downstream)
        best_path = [p + [0] * (feats.size(1) - len(p)) for p in best_path]
        best_path = torch.LongTensor(best_path)
        if feats.is_cuda:
            best_path = best_path.cuda()
        return best_path
    def _get_bert_features(self, input_ids, token_type_ids, attention_mask):
        sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
        sequence_output = self.dropout(sequence_output)
        bert_feats = self.classifier(sequence_output)
        return bert_feats
    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
        # attention_mask must be provided: the CRF routines use it to skip padding
        bert_feats = self._get_bert_features(input_ids, token_type_ids, attention_mask)
        if labels is not None:
            # training: mean negative log-likelihood, log Z(x) - score(x, y)
            forward_score = self._batch_forward_alg(bert_feats, attention_mask)
            gold_score = self._batch_score_sentence(bert_feats, labels, attention_mask)
            return (forward_score - gold_score).mean()
        else:
            # inference: best tag sequence via Viterbi decoding
            tag_seq = self._batch_viterbi_decode(bert_feats, attention_mask)
            return tag_seq
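For reference, here is a hedged sketch of how the class can be driven. Everything below (the label list, the dummy tensors, the choice of bert-base-cased) is illustrative and not taken from the original; the only fixed points are the class's own signatures and the pytorch-pretrained-bert from_pretrained call, which forwards extra kwargs such as labels to __init__.

# Hypothetical usage sketch -- all names and shapes here are made up.
labels = ["O", "B-PER", "I-PER", START_TAG, STOP_TAG]
model = BertWithCRF.from_pretrained("bert-base-cased", labels=labels)

input_ids = torch.randint(0, 28996, (2, 16))          # dummy token ids, [B, T]
attention_mask = torch.ones(2, 16, dtype=torch.long)  # 1 = real token, 0 = pad
tags = torch.randint(0, 3, (2, 16))                   # dummy gold tags (real labels only)

# training step: forward returns the mean negative log-likelihood
loss = model(input_ids, attention_mask=attention_mask, labels=tags)
loss.backward()

# inference: forward without labels returns the Viterbi paths, [B, T]
with torch.no_grad():
    paths = model(input_ids, attention_mask=attention_mask)

Note that the gold tags should use only the real label indices; START_TAG and STOP_TAG are handled internally by the CRF and never appear in the input or output sequences.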