其他答案已经很好了,这里分享一个 softmax 的 forward/backward 的简单实现,它与所使用的损失函数无关。
下图给出了 softmax 的 backward 的简要推导。其中第二个等式取决于所用的损失函数,不属于我们的实现部分。

backward 已通过手动梯度检查(与数值梯度对比)验证。
import numpy as np
class Softmax:
    """Row-wise softmax layer with a hand-written forward and backward pass.

    The layer keeps no state: ``backward`` receives the ``probs`` returned by
    ``forward`` instead of caching them.
    """

    def forward(self, x):
        """Return the softmax of each row of ``x`` (taken along axis 1).

        Args:
            x: 2-D array of logits.

        Returns:
            Array with the same shape as ``x`` whose rows sum to 1.
        """
        # Subtracting the row max does not change the result but keeps
        # np.exp from overflowing on large logits.
        shifted = x - np.max(x, axis=1, keepdims=True)
        e = np.exp(shifted)
        # Reuse ``e`` rather than exponentiating a second time.
        return e / np.sum(e, axis=1, keepdims=True)

    def backward(self, x, probs, bp_err):
        """Back-propagate ``bp_err`` (dL/dprobs) through the softmax.

        With d prob_k / d x_j = prob_k * (delta_kj - prob_j), the
        vector-Jacobian product for each row collapses to
        ``probs * (bp_err - sum_k(bp_err_k * prob_k))``. This is the same
        quantity the per-column Jacobian loop computes, without materializing
        the Jacobian.

        Args:
            x: The input that was passed to ``forward``. It is kept for
                interface compatibility and only its shape matters.
            probs: Output of ``forward(x)``.
            bp_err: Upstream gradient dL/dprobs, same shape as ``probs``.

        Returns:
            dL/dx, with the same shape as ``x``.
        """
        # Per-row scalar sum_k bp_err_k * prob_k, kept as a column for broadcasting.
        weighted = np.sum(bp_err * probs, axis=1, keepdims=True)
        return (probs * (bp_err - weighted)).reshape(x.shape)
def compute_manual_grads(x, pred_fn, eps=1e-3):
    """Estimate the gradient of scalar ``pred_fn(x)`` w.r.t. ``x`` numerically.

    Uses central finite differences: each entry of ``x`` is temporarily set to
    ``value + eps`` and ``value - eps`` in place. It is then restored to its
    exact original value, even if ``pred_fn`` raises, so the caller's array
    is unchanged on return.

    Args:
        x: Float array of any shape. It is modified temporarily during the call.
        pred_fn: Callable taking an array shaped like ``x`` and returning a scalar.
        eps: Finite-difference step size.

    Returns:
        Array with the same shape as ``x`` holding the numerical gradient.
    """
    grads = np.empty(x.shape)
    for idx in np.ndindex(x.shape):
        original = x[idx]
        try:
            x[idx] = original + eps
            y_plus = pred_fn(x)
            x[idx] = original - eps
            y_minus = pred_fn(x)
        finally:
            # Assign the saved value back instead of undoing the perturbation
            # arithmetically: +eps / -2*eps / +eps can leave rounding drift.
            x[idx] = original
        grads[idx] = (y_plus - y_minus) / (2 * eps)
    return grads
def loss_fn(probs, ys, loss_type):
    """Compute a toy batch-mean loss on ``probs`` and its gradient w.r.t. ``probs``.

    Args:
        probs: 2-D array of predicted probabilities, one row per sample.
        ys: Integer target column index per row (array-like). It is reshaped
            to a column internally.
        loss_type: ``"mse"`` for the dummy squared error ``(p_y - 1)^2``, or
            ``"xent"`` for cross-entropy ``-log(p_y)``. Both are averaged over
            the batch.

    Returns:
        ``(loss, err)``. ``loss`` is the scalar batch-mean loss. ``err`` has
        the same shape as ``probs``; it is zero everywhere except each row's
        target column, which holds dloss/dprobs for that entry.

    Raises:
        ValueError: If ``loss_type`` is neither ``"mse"`` nor ``"xent"``.
    """
    batch_size = probs.shape[0]
    idx = np.asarray(ys).reshape(-1, 1)
    # Probability assigned to each row's target column, shape (batch, 1).
    picked = np.take_along_axis(probs, idx, axis=1)
    if loss_type == "mse":
        # Dummy MSE against a target probability of 1.
        loss = np.sum((picked - 1) ** 2) / batch_size
        values = 2 * (picked - 1) / batch_size
    elif loss_type == "xent":
        # Cross-entropy. Only the target probabilities are logged.
        loss = -np.sum(np.log(picked)) / batch_size
        values = -1 / picked / batch_size
    else:
        raise ValueError(f"unknown loss_type: {loss_type!r}")
    err = np.zeros(probs.shape)
    np.put_along_axis(err, idx, values, axis=1)
    return loss, err
if __name__ == "__main__":
    # Gradient check: compare Softmax.backward against a finite-difference
    # estimate of the same scalar loss, for each supported loss type.
    batch_size = 10
    dim = 5
    x = np.random.rand(batch_size, dim)
    ys = np.random.randint(0, dim, batch_size)
    softmax = Softmax()  # stateless, so one instance serves every loss type

    for loss_type in ("mse", "xent"):
        # Analytic path: dL/dprobs from loss_fn, pushed back through softmax.
        probs = softmax.forward(x)
        _, upstream = loss_fn(probs, ys, loss_type)
        grads = softmax.backward(x, probs, upstream)

        # Numerical path: scalar loss as a function of the logits only.
        def objective(inputs, kind=loss_type):
            value, _ = loss_fn(softmax.forward(inputs), ys, kind)
            return value

        manual_grads = compute_manual_grads(x, objective)

        # Batch-averaged squared difference; should be close to zero.
        diff = np.sum((grads - manual_grads) ** 2) / batch_size
        print(f"loss_type = {loss_type}, grad diff = {diff}")