5

查看代码片段:

import torch
x = torch.tensor([-1.], requires_grad=True)
y = torch.where(x > 0., x, torch.tensor([2.], requires_grad=True))
y.backward()
print(x.grad)

输出是tensor([0.]),但是

import torch
x = torch.tensor([-1.], requires_grad=True)
if x > 0.:
    y = x
else:
    y = torch.tensor([2.], requires_grad=True)
y.backward()
print(x.grad)

输出是None

我很困惑,为什么输出torch.whereis tensor([0.])

更新

import torch
a = torch.tensor([[1,2.], [3., 4]])
b = torch.tensor([-1., -1], requires_grad=True)
a[:,0] = b

(a[0, 0] * a[0, 1]).backward()
print(b.grad)

输出是tensor([2., 0.])。与没有(a[0, 0] * a[0, 1])任何关系b[1],但与 的梯度b[1]无关。0None

4

1 回答 1

4

基于跟踪的 AD,如 pytorch,通过tracking工作。您无法跟踪不是库拦截的函数调用的内容。通过使用这样的语句,andif之间没有连接,而 with和在表达式树中被链接。xywherexy

现在,对于差异:

  • 在第一个片段中,0是函数x ↦ x > 0 ? x : 2在该点的正确导数-1(因为负侧是常数)。
  • 正如我所说,在第二个片段中,与(在分支中)x没有任何关系。因此,given的导数是未定义的,表示为.yelseyxNone

(你甚至可以在 Python 中做这样的事情,但这需要更复杂的技术,比如源代码转换。我不认为 pytorch 可以做到。)

于 2020-04-13T09:13:47.617 回答