2

我有一个不平衡的数据集,我的任务是多标签分类

这是我减少损失的代码:

logits = inference(input)
xent = tf.nn.sigmoid_cross_entropy_with_logits(
        logits=logits, labels=labels, name='xent')
loss = tf.reduce_mean(xent, name='loss_op')

现在。我想weighted-loss用于我的分类,我该怎么做呢?我可以使用这个链接,并替换softmaxsigmoid

观点

我已阅读此链接,但我的情况不是二进制分类,在tensorflow_org中我认为它也适用于二进制分类。

4

1 回答 1

2

你可以使用tf.losses.compute_weighted_loss. 我建议阅读代码以确切了解此功能的工作原理,但您应该能够大致做到:

logits = inference(input)
xent = tf.nn.sigmoid_cross_entropy_with_logits(
    logits=logits, labels=labels, name='xent')
weighted_loss = tf.losses.compute_weighted_loss(xent, YOUR_WEIGHTS, name='weighted_loss_op')
于 2018-03-17T18:31:01.620 回答