16

我正在研究文本分类问题(例如情感分析),我需要将文本字符串分类为五个类别之一。

我刚开始使用Huggingface Transformer包和带有 PyTorch 的 BERT。我需要的是一个顶部有一个 softmax 层的分类器,这样我就可以进行 5 路分类。令人困惑的是,Transformer 包中似乎有两个相关选项:BertForSequenceClassificationBertForMultipleChoice

我应该使用哪一个来完成我的 5 路分类任务?它们有哪些合适的用例?

BertForSequenceClassification的文档根本没有提到 softmax,尽管它确实提到了交叉熵。我不确定这个类是否仅用于 2 类分类(即逻辑回归)。

顶部带有序列分类/回归头的 Bert 模型转换器(池化输出顶部的线性层),例如用于 GLUE 任务。

  • 标签(torch.LongTensor of shape (batch_size,), optional, 默认为 None) -- 用于计算序列分类/回归损失的标签。索引应该在 [0, ..., config.num_labels - 1] 中。如果 config.num_labels == 1 计算回归损失(均方损失),如果 config.num_labels > 1 计算分类损失(交叉熵)。

BertForMultipleChoice的文档中提到了 softmax,但是标签的描述方式,听起来这个类是用于多标签分类(即多标签的二元分类)。

顶部具有多项选择分类头的 Bert 模型(池输出顶部的线性层和 softmax),例如用于 RocStories/SWAG 任务。

  • 标签(torch.LongTensor of shape (batch_size,), optional, 默认为 None) -- 用于计算多项选择分类损失的标签。索引应该在 [0, ..., num_choices] 中,其中 num_choices 是输入张量的第二维的大小。

感谢您的任何帮助。

4

1 回答 1

12

对此的答案在于(诚然非常简短)对任务内容的描述:

[ BertForMultipleChoice] [...],例如用于 RocStories/SWAG 任务。

在查看SWAG 的论文时,似乎该任务实际上是在学习从不同的选项中进行选择。这与您的“经典”分类任务形成对比,其中“选择”(即类别)不会因您的样本而变化,这正是BertForSequenceClassification它的用途。

通过更改配置中的参数,这两种变体实际上都可以用于任意数量的类(在 的情况下BertForSequenceClassification),分别用于选择(for )。但是,由于您似乎正在处理“经典分类”的案例,我建议使用该模型。BertForMultipleChoicelabelsBertForSequenceClassification

很快解决了缺少的 Softmax BertForSequenceClassification:由于分类任务可以计算与样本无关的类之间的损失(与多项选择不同,您的分布正在变化),这允许您使用交叉熵损失,它在反向传播步骤中考虑了 Softmax增加数值稳定性

于 2020-03-10T10:41:23.373 回答