我正在尝试使用 TensorFlow 实现多标签分类(即,每个输出模式可以有许多活动单元)。该问题具有不平衡的类别(即,标签分布中的零比零多得多,这使得标签模式非常稀疏)。
解决问题的最佳方法应该是使用该tf.nn.weighted_cross_entropy_with_logits
功能。但是,我收到此运行时错误:
ValueError: Tensor conversion requested dtype uint8 for Tensor with dtype float32
我不明白这里有什么问题。作为损失函数的输入,我传递了标签张量、logits 张量和正类权重,这是一个常数:
positive_class_weight = 10
loss = tf.nn.weighted_cross_entropy_with_logits(targets=labels, logits=logits, pos_weight=positive_class_weight)
关于如何解决这个问题的任何提示?如果我只是将相同的标签和 logits 张量传递给tf.losses.sigmoid_cross_entropy
损失函数,那么一切正常(在 Tensorflow 运行正常的意义上,但当然训练预测总是为零)。
在此处查看相关问题。