4

我有一个多标签分类问题,我试图用 Pytorch 中的 CNN 解决这个问题。我有 80,000 个训练示例和 7900 个课程;每个示例可以同时属于多个类,每个示例的平均类数为 130。

问题是我的数据集非常不平衡。对于某些课程,我只有大约 900 个示例,大约 1%。对于“过度代表”的课程,我有大约 12000 个示例(15%)。当我训练模型时,我使用来自pytorch 的BCEWithLogitsLoss和正权重参数。我计算权重的方式与文档中描述的相同:负例数除以正例数。

结果,我的模型几乎高估了每一个类……无论是小类还是大类,我得到的预测几乎是真实标签的两倍。而我的 AUPRC 只有 0.18。尽管它比完全没有加权要好得多,因为在这种情况下,模型将所有内容都预测为零。

所以我的问题是,如何提高性能?还有什么我可以做的吗?我尝试了不同的批量采样技术(对少数类进行过采样),但它们似乎不起作用。

4

2 回答 2

6

我会建议其中一种策略

焦点损失


Tsung-Yi Lin、Priya Goyal、Ross Girshick、Kaiming He 和 Piotr Dollar Focal Loss for Dense Object Detection (ICCV 2017)介绍了一种通过调整损失函数来处理不平衡训练数据的非常有趣的方法。
他们建议修改二元交叉熵损失,以减少易于分类示例的损失和梯度,同时“将努力”集中在模型出现严重错误的示例上。

硬负挖掘

另一种流行的方法是做“硬负挖掘”;也就是说,只为部分训练样本传播梯度——“硬”样本。
参见,例如:
Abhinav Shrivastava、Abhinav Gupta 和 Ross Girshick 使用在线困难示例挖掘训练基于区域的目标检测器(CVPR 2016)

于 2019-10-03T06:13:43.933 回答
0

@Shai 提供了在深度学习时代开发的两种策略。我想为您提供一些额外的传统机器学习选项:过采样欠采样

它们的主要思想是在开始训练之前通过采样产生更平衡的数据集。请注意,您可能会面临一些问题,例如丢失数据多样性(欠采样)和过度拟合训练数据(过采样),但这可能是一个很好的起点。

有关更多信息,请参阅wiki 链接

于 2019-10-03T09:42:33.683 回答