我试图通过实现牛顿求解x = cos(x)的方法来深入了解 PyTorch 的工作原理。这是一个有效的版本:
x = Variable(DoubleTensor([1]), requires_grad=True)
for i in range(5):
y = x - torch.cos(x)
y.backward()
x = Variable(x.data - y.data/x.grad.data, requires_grad=True)
print(x.data) # tensor([0.7390851332151607], dtype=torch.float64) (correct)
这段代码对我来说似乎不优雅(效率低下?),因为它在for
循环的每个步骤中都重新创建了整个计算图(对吗?)。我试图通过简单地更新每个变量持有的数据而不是重新创建它们来避免这种情况:
x = Variable(DoubleTensor([1]), requires_grad=True)
y = x - torch.cos(x)
y.backward(retain_graph=True)
for i in range(5):
x.data = x.data - y.data/x.grad.data
y.data = x.data - torch.cos(x.data)
y.backward(retain_graph=True)
print(x.data) # tensor([0.7417889255761136], dtype=torch.float64) (wrong)
似乎,对于DoubleTensor
s,我携带了足够多的精度来排除舍入误差。那么错误来自哪里?
可能相关:如果循环,上面的代码片段retain_graph=True
在每一步都没有设置标志的情况下中断。for
如果我在循环中省略它——但在第 3 行保留它——我得到的错误消息是:
RuntimeError: Trying backing through the graph a second time, but the buffers have been freed. 第一次向后调用时指定retain_graph=True 。这似乎证明我误解了什么......