这是一个非常简单的例子:
import torch
x = torch.tensor([1., 2., 3., 4., 5.], requires_grad=True)
y = torch.tensor([2., 2., 2., 2., 2.], requires_grad=True)
z = torch.tensor([1., 1., 0., 0., 0.], requires_grad=True)
s = torch.sum(x * y * z)
s.backward()
print(x.grad)
这将打印,
tensor([2., 2., 0., 0., 0.]),
因为,当然,对于 z 为零的条目,ds/dx 为零。
我的问题是:pytorch 是否智能并在达到零时停止计算?还是实际上做计算“ 2*5
”,只是为了以后做“ 10 * 0 = 0
”?
在这个简单的例子中,它并没有太大的区别,但是在我正在研究的(更大的)问题中,这会有所作为。
感谢您的任何意见。