6

我正在阅读对数损失和交叉熵,似乎有两种计算方法,基于以下等式。

在此处输入图像描述

第一个是以下

import numpy as np
from sklearn.metrics import log_loss


def cross_entropy(predictions, targets):
    N = predictions.shape[0]
    ce = -np.sum(targets * np.log(predictions)) / N
    return ce


predictions = np.array([[0.25,0.25,0.25,0.25],
                        [0.01,0.01,0.01,0.97]])
targets = np.array([[1,0,0,0],
                   [0,0,0,1]])

x = cross_entropy(predictions, targets)
print(log_loss(targets, predictions), 'our_answer:', ans)

上一个程序的输出是0.7083767843022996 our_answer: 0.71355817782,几乎是一样的。所以这不是问题。

上面的实现是上面等式的中间部分。

第二种方法基于上述等式的 RHS 部分。

res = 0
for act_row, pred_row in zip(targets, np.array(predictions)):
    for class_act, class_pred in zip(act_row, pred_row):
        res += - class_act * np.log(class_pred) - (1-class_act) * np.log(1-class_pred)

print(res/len(targets))

输出是1.1549753967602232,这不太一样。

我用 NumPy 尝试了相同的实现,但它也没有工作。我究竟做错了什么?

PS:我也很好奇,-y log (y_hat)在我看来这和- sigma(p_i * log( q_i))那怎么会有一个-(1-y) log(1-y_hat)部分一样。显然我误解了如何-y log (y_hat)计算。

4

1 回答 1

10

我无法重现您在第一部分中报告的结果的差异(您还引用了一个ans变量,您似乎没有定义它,我猜它是x):

import numpy as np
from sklearn.metrics import log_loss


def cross_entropy(predictions, targets):
    N = predictions.shape[0]
    ce = -np.sum(targets * np.log(predictions)) / N
    return ce

predictions = np.array([[0.25,0.25,0.25,0.25],
                        [0.01,0.01,0.01,0.97]])
targets = np.array([[1,0,0,0],
                   [0,0,0,1]])

结果:

cross_entropy(predictions, targets)
# 0.7083767843022996

log_loss(targets, predictions)
# 0.7083767843022996

log_loss(targets, predictions) == cross_entropy(predictions, targets)
# True

您的cross_entropy功能似乎工作正常。

关于第二部分:

显然我误解了如何-y log (y_hat)计算。

确实,更仔细地阅读您链接到的 fast.ai wiki,您会发现等式的 RHS 仅适用于二元分类(其中始终为 1y1-y将为零),此处并非如此 - 您有4 类多项式分类。所以,正确的公式是

res = 0
for act_row, pred_row in zip(targets, np.array(predictions)):
    for class_act, class_pred in zip(act_row, pred_row):
        res += - class_act * np.log(class_pred)

即丢弃 的减法(1-class_act) * np.log(1-class_pred)

结果:

res/len(targets)
# 0.7083767843022996

res/len(targets) == log_loss(targets, predictions)
# True

在更一般的层面上(二进制分类的日志丢失和准确性机制),您可能会发现这个答案很有用。

于 2018-03-25T09:22:20.600 回答