3

这是一个非常简单的例子:

import torch

x = torch.tensor([1., 2., 3., 4., 5.], requires_grad=True)
y = torch.tensor([2., 2., 2., 2., 2.], requires_grad=True)
z = torch.tensor([1., 1., 0., 0., 0.], requires_grad=True)

s = torch.sum(x * y * z)
s.backward()

print(x.grad)

这将打印,

tensor([2., 2., 0., 0., 0.]),

因为,当然,对于 z 为零的条目,ds/dx 为零。

我的问题是:pytorch 是否智能并在达到零时停止计算?还是实际上做计算“ 2*5”,只是为了以后做“ 10 * 0 = 0”?

在这个简单的例子中,它并没有太大的区别,但是在我正在研究的(更大的)问题中,这会有所作为。

感谢您的任何意见。

4

1 回答 1

1

不,pytorch 不会在达到零时修剪任何后续计算。更糟糕的是,由于浮点运算的工作原理,所有后续乘以零的时间与任何常规乘法的时间大致相同。

在某些情况下,有一些方法可以解决它,例如,如果您想使用掩码损失,您可以掩码输出设置为零,或者将它们与梯度分离。

这个例子清楚地表明了区别:

def time_backward(do_detach):
    x = torch.tensor(torch.rand(100000000), requires_grad=True)
    y = torch.tensor(torch.rand(100000000), requires_grad=True)
    s2 = torch.sum(x * y)
    s1 = torch.sum(x * y)
    if do_detach:
        s2 = s2.detach()
    s = s1 + 0 * s2
    t = time.time()
    s.backward()
    print(time.time() - t)

time_backward(do_detach= False)
time_backward(do_detach= True)

输出:

0.502875089645
0.198422908783
于 2019-02-20T10:43:04.127 回答