4

我想知道 Python 变量被覆盖的 PyTorch 张量是否仍保留在 PyTorch 的计算图中。


所以这是一个小例子,我有一个 RNN 模型,其中隐藏状态(和其他一些变量)在每次迭代后被重置,backward()稍后调用。

例子:

for i in range(5):
   output = rnn_model(inputs[i])
   loss += criterion(output, target[i])
   ## hidden states are overwritten with a zero vector
   rnn_model.reset_hidden_states() 
loss.backward()

所以我的问题是:

  • 在调用之前覆盖隐藏状态是否有问题backward()

  • 或者计算图是否将先前迭代的隐藏状态的必要信息保留在内存中以计算梯度?

  • 编辑:如果有官方消息来源的声明会很棒。例如,声明所有与 CG 相关的变量都被保留 - 无论是否还有其他针对该变量的 python 引用。我假设图表本身中有一个引用阻止垃圾收集器删除它。但我想知道这是否真的如此。

提前致谢!

4

1 回答 1

0

我认为在倒退之前重置是可以的。该图保留了所需的信息。

class A (torch.nn.Module):
     def __init__(self):
         super().__init__()
         self.f1 = torch.nn.Linear(10,1)
     def forward(self, x):
         self.x = x 
         return torch.nn.functional.sigmoid (self.f1(self.x))
     def reset_x (self):
        self.x = torch.zeros(self.x.shape) 
net = A()
net.zero_grad()
X = torch.rand(10,10) 
loss = torch.nn.functional.binary_cross_entropy(net(X), torch.ones(10,1))
loss.backward()
params = list(net.parameters())
for i in params: 
    print(i.grad)
net.zero_grad() 

loss = torch.nn.functional.binary_cross_entropy(net(X), torch.ones(10,1))
net.reset_x()
print (net.x is X)
del X
loss.backward()     
params = list(net.parameters())
for i in params:
    print(i.grad)

在上面的代码中,我打印带有/不重置输入 x 的毕业生。梯度肯定取决于 x 并且重置它并不重要。因此,我认为图形保留了信息以进行反向操作。

于 2018-10-11T23:37:23.257 回答