9

我很难计算张量流中的交叉熵。特别是,我正在使用以下功能:

tf.nn.softmax_cross_entropy_with_logits()

使用看似简单的代码,我只能让它返回零

import tensorflow as tf
import numpy as np

sess = tf.InteractiveSession()

a = tf.placeholder(tf.float32, shape =[None, 1])
b = tf.placeholder(tf.float32, shape = [None, 1])
sess.run(tf.global_variables_initializer())
c = tf.nn.softmax_cross_entropy_with_logits(
    logits=b, labels=a
).eval(feed_dict={b:np.array([[0.45]]), a:np.array([[0.2]])})
print c

返回

0

我对交叉熵的理解如下:

H(p,q) = p(x)*log(q(x))

其中 p(x) 是事件 x 的真实概率,q(x) 是事件 x 的预测概率。

如果输入 p(x) 和 q(x) 的任意两个数字,则使用

0<p(x)<1 AND 0<q(x)<1

应该有一个非零交叉熵。我期待我错误地使用了 tensorflow。提前感谢您的帮助。

4

3 回答 3

19

除了 Don 的答案 (+1),mrry 写的这个答案可能会让您感兴趣,因为它给出了计算 TensorFlow 中交叉熵的公式:

另一种写法:

xent = tf.nn.softmax_cross_entropy_with_logits(logits, labels)

...将会:

softmax = tf.nn.softmax(logits)
xent = -tf.reduce_sum(labels * tf.log(softmax), 1)

然而,这种替代方案将 (i) 数值稳定性较差(因为 softmax 可能计算更大的值)和 (ii) 效率较低(因为在反向传播中会发生一些冗余计算)。对于实际使用,我们建议您使用 tf.nn.softmax_cross_entropy_with_logits().

于 2017-03-01T01:58:26.930 回答
14

就像他们说的那样,没有“softmax”就不能拼写“softmax_cross_entropy_with_logits”。的 Softmax[0.45][1],并且log(1)0

测量类别互斥(每个条目恰好属于一个类别)的离散分类任务中的概率误差。例如,每张 CIFAR-10 图像都标有一个且只有一个标签:图像可以是狗或卡车,但不能同时是两者。

注意: 虽然类是互斥的,但它们的概率不一定是互斥的。所需要的只是每一行labels都是一个有效的概率分布。如果不是,则梯度的计算将不正确。

如果使用独占labels(其中一次只有一个类为真),请参阅sparse_softmax_cross_entropy_with_logits

警告:此操作需要未缩放的 logits,因为它在内部执行softmax onlogits以提高效率。不要使用 的输出调用此操作softmax,因为它会产生不正确的结果。

logits并且labels必须具有相同的形状[batch_size, num_classes] 和相同的 dtype(或者float16float32float64)。

于 2017-03-01T01:49:38.460 回答
2

这是 Tensorflow 2.0 中的一个实现,以防其他人(我可能)将来需要它。

@tf.function
def cross_entropy(x, y, epsilon = 1e-9):
    return -2 * tf.reduce_mean(y * tf.math.log(x + epsilon), -1) / tf.math.log(2.)

x = tf.constant([
    [1.0,0],
    [0.5,0.5],
    [.75,.25]
    ]
,dtype=tf.float32)

with tf.GradientTape() as tape:
    tape.watch(x)
    y = entropy(x, x)

tf.print(y)
tf.print(tape.gradient(y, x))

输出

[-0 1 0.811278105]
[[-1.44269502 29.8973541]
 [-0.442695022 -0.442695022]
 [-1.02765751 0.557305]]
于 2020-09-08T04:20:29.487 回答