我想知道 Python 变量被覆盖的 PyTorch 张量是否仍保留在 PyTorch 的计算图中。
所以这是一个小例子,我有一个 RNN 模型,其中隐藏状态(和其他一些变量)在每次迭代后被重置,
backward()
稍后调用。
例子:
for i in range(5):
output = rnn_model(inputs[i])
loss += criterion(output, target[i])
## hidden states are overwritten with a zero vector
rnn_model.reset_hidden_states()
loss.backward()
所以我的问题是:
在调用之前覆盖隐藏状态是否有问题
backward()
?或者计算图是否将先前迭代的隐藏状态的必要信息保留在内存中以计算梯度?
编辑:如果有官方消息来源的声明会很棒。例如,声明所有与 CG 相关的变量都被保留 - 无论是否还有其他针对该变量的 python 引用。我假设图表本身中有一个引用阻止垃圾收集器删除它。但我想知道这是否真的如此。
提前致谢!