请仔细阅读文档backward()
以更好地理解它。
默认情况下,pytorch 期望backward()
为网络的最后一个输出调用 - 损失函数。损失函数总是输出一个标量,因此,所有其他变量/参数的标量损失的梯度是明确定义的(使用链式法则)。
因此,默认情况下,backward()
在标量张量上调用并且不需要参数。
例如:
a = torch.tensor([[1,2,3],[4,5,6]], dtype=torch.float, requires_grad=True)
for i in range(2):
for j in range(3):
out = a[i,j] * a[i,j]
out.backward()
print(a.grad)
产量
tensor([[ 2., 4., 6.],
[ 8., 10., 12.]])
正如预期的那样:d(a^2)/da = 2a
。
但是,当您调用backward
2×3out
张量(不再是标量函数)时,您希望a.grad
是什么?你实际上需要一个 2×3×2×3 的输出:(d out[i,j] / d a[k,l]
!)
Pytorch 不支持这种非标量函数导数。相反,pytorch 假设out
只是一个中间张量,并且在“上游”某处有一个标量损失函数,通过链式法则提供d loss/ d out[i,j]
. 这个“上游”梯度的大小为 2×3,这实际上是您backward
在这种情况下提供的参数:out.backward(g)
where g_ij = d loss/ d out_ij
。
然后通过链式法则计算梯度d loss / d a[i,j] = (d loss/d out[i,j]) * (d out[i,j] / d a[i,j])
由于您提供a
了“上游”渐变,因此您得到了
a.grad[i,j] = 2 * a[i,j] * a[i,j]
如果您要提供“上游”渐变为全部
out.backward(torch.ones(2,3))
print(a.grad)
产量
tensor([[ 2., 4., 6.],
[ 8., 10., 12.]])
正如预期的那样。
这一切都在链式法则中。