8

在HuggingFace进行的最后几层序列分类中,他们将 Transformer 输出的序列长度的第一个隐藏状态用于分类。

hidden_state = distilbert_output[0]  # (bs, seq_len, dim) <-- transformer output
pooled_output = hidden_state[:, 0]  # (bs, dim)           <-- first hidden state
pooled_output = self.pre_classifier(pooled_output)  # (bs, dim)
pooled_output = nn.ReLU()(pooled_output)  # (bs, dim)
pooled_output = self.dropout(pooled_output)  # (bs, dim)
logits = self.classifier(pooled_output)  # (bs, dim)

取第一个隐藏状态而不是最后一个、平均甚至使用 Flatten 层有什么好处吗?

4

1 回答 1

6

是的,这与 BERT 的训练方式直接相关。具体来说,我鼓励您查看原始的 BERT 论文,其中作者介绍了[CLS]令牌的含义:

[CLS]是在每个输入示例前添加的特殊符号 [...]。

具体来说,它用于分类目的,因此是分类任务微调的第一个也是最简单的选择。您的相关代码片段正在做什么,基本上只是提取这个[CLS]令牌。

不幸的是,Huggingface 库的 DistilBERT 文档没有明确提及这一点,但您必须查看他们的BERT 文档,其中他们还强调了[CLS]令牌的一些问题,类似于您的担忧:

除了 MLM,BERT 还使用下一句预测 (NSP) 目标进行训练,使用 [CLS] 标记作为序列近似值。用户可以使用这个标记(用特殊标记构建的序列中的第一个标记)来获得序列预测而不是标记预测。但是,对序列进行平均可能会比使用 [CLS] 令牌产生更好的结果。

于 2020-02-20T09:03:53.683 回答