18

我试图backpropagation在一个简单的 3 层神经网络中理解MNIST.

输入层带有weightsbias。标签是MNIST这样的,它是一个10类向量。

第二层是一个linear tranform。第三层是将softmax activation输出作为概率。

Backpropagation计算每一步的导数并将其称为梯度。

以前的图层将globalprevious渐变附加到local gradient. 我无法local gradient计算softmax

网上的一些资源对softmax及其派生词进行了解释,甚至给出了softmax本身的代码示例

def softmax(x):
    """Compute the softmax of vector x."""
    exps = np.exp(x)
    return exps / np.sum(exps)

导数是关于 wheni = j和 when来解释的i != j。这是我想出的一个简单的代码片段,希望能验证我的理解:

def softmax(self, x):
    """Compute the softmax of vector x."""
    exps = np.exp(x)
    return exps / np.sum(exps)

def forward(self):
    # self.input is a vector of length 10
    # and is the output of 
    # (w * x) + b
    self.value = self.softmax(self.input)

def backward(self):
    for i in range(len(self.value)):
        for j in range(len(self.input)):
            if i == j:
                self.gradient[i] = self.value[i] * (1-self.input[i))
            else: 
                 self.gradient[i] = -self.value[i]*self.input[j]

然后self.gradientlocal gradientwhich 是一个向量。这个对吗?有没有更好的方法来写这个?

4

3 回答 3

25

我假设您有一个 3 层神经网络,其中W1,b1与从输入层到隐藏层的线性变换相关联,并且W2,b2与从隐藏层到输出层的线性变换相关联。Z1是隐藏层和Z2输出层的输入向量。a1表示隐藏层和a2输出层的输出。a2是您的预测输出。delta3并且delta2是误差(反向传播),您可以看到损失函数相对于模型参数的梯度。

在此处输入图像描述 在此处输入图像描述

这是 3 层 NN(输入层,只有一个隐藏层和一个输出层)的一般场景。您可以按照上面描述的过程来计算应该很容易计算的梯度!由于这篇文章的另一个答案已经指出了您代码中的问题,因此我不再重复。

于 2016-11-13T17:45:42.007 回答
11

np.exp不稳定,因为它有 Inf。所以你应该减去最大值x

def softmax(x):
    """Compute the softmax of vector x."""
    exps = np.exp(x - x.max())
    return exps / np.sum(exps)

如果x是矩阵,请检查此笔记本中的 softmax 函数。

于 2016-11-14T14:22:57.430 回答
11

正如我所说,你有n^2偏导数。

如果你做数学,你会发现你应该有dSM[i]/dx[k]SM[i] * (dx[i]/dx[k] - SM[i])

if i == j:
    self.gradient[i,j] = self.value[i] * (1-self.value[i])
else: 
    self.gradient[i,j] = -self.value[i] * self.value[j]

代替

if i == j:
    self.gradient[i] = self.value[i] * (1-self.input[i])
else: 
     self.gradient[i] = -self.value[i]*self.input[j]

顺便说一句,这可以像这样更简洁地计算(矢量化):

SM = self.value.reshape((-1,1))
jac = np.diagflat(self.value) - np.dot(SM, SM.T)
于 2016-11-13T17:44:16.707 回答