After running the following snippet (PyTorch 1.7.1; Python 3.8),
import numpy as np
import torch


def batch_matrix(vector_pairs, factor=2):
    baselen = len(vector_pairs[0]) // factor
    split_batch = []
    for j in range(factor):
        for i in range(factor):
            start_j = j * baselen
            end_j = (j + 1) * baselen if j != factor - 1 else None
            start_i = i * baselen
            end_i = (i + 1) * baselen if i != factor - 1 else None
            mini_pairs = vector_pairs[start_j:end_j, start_i:end_i, :]
            split_batch.append(mini_pairs)
    return split_batch


def concat_matrix(vectors_):
    vectors = vectors_.clone()
    seq_len, dim_vec = vectors.shape
    project_x = vectors.repeat((1, 1, seq_len)).reshape(seq_len, seq_len, dim_vec)
    project_y = project_x.permute(1, 0, 2)
    matrix = torch.cat((project_x, project_y), dim=-1)
    matrix_ = matrix.clone()
    return matrix_


if __name__ == "__main__":
    vector_list = []
    for i in range(10):
        vector_list.append(torch.randn((5,), requires_grad=True))
    vectors = torch.stack(vector_list, dim=0)
    pmatrix = concat_matrix(vectors)
    factor = np.ceil(vectors.shape[0] / 6).astype(int)
    batched_feats = batch_matrix(pmatrix, factor=factor)
    for i in batched_feats:
        i = i + 5
        print(i.shape)
        summed = torch.sum(i)
        summed.backward()
I get the following output and error:
torch.Size([5, 5, 10])
torch.Size([5, 5, 10])
Traceback (most recent call last):
File "/home/user/PycharmProjects/project/run.py", line 43, in <module>
summed.backward()
File "/home/user/anaconda3/envs/diff/lib/python3.8/site-packages/torch/tensor.py", line 221, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph)
File "/home/user/anaconda3/envs/diff/lib/python3.8/site-packages/torch/autograd/__init__.py", line 130, in backward
Variable._execution_engine.run_backward(
RuntimeError: Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling backward the first time.
I have read all the existing posts on this issue but could not resolve it myself. Passing retain_graph=True to backward() fixes the problem in the snippet above; however, the snippet is just a heavily simplified version of a large network, in which retain_graph=True changes the error to the following:
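For reference, this is the retain_graph pattern I applied to the snippet, shown here on a tiny standalone example (the tensors x and y below are made up for illustration, not from my network): every backward() except the last one keeps the saved intermediates alive so the shared graph can be traversed again.

```python
import torch

x = torch.randn(3, requires_grad=True)
y = x * 2                # shared intermediate: both slices below reuse its graph
parts = [y[:1], y[1:]]   # two losses built on the same graph

for k, part in enumerate(parts):
    loss = part.sum()
    # retain_graph=True on all but the final backward(), so the saved
    # intermediates are not freed until the last pass through the graph
    loss.backward(retain_graph=(k < len(parts) - 1))
```

Gradients accumulate into x.grad across the two backward() calls, which is exactly what happens per-tensor in my loop over batched_feats.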
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [3000, 512]], which is output 0 of TBackward, is at version 3; expected version 2 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
I tried setting torch.autograd.set_detect_anomaly(True) to pinpoint the failing operation, but everything I tried failed and the error persists.
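This is how I enabled it, in case I am holding it wrong (both forms are standard torch.autograd API; the small forward pass here is just a placeholder for my network):

```python
import torch

# Global switch: every subsequent backward() records forward-pass tracebacks
torch.autograd.set_detect_anomaly(True)

# Alternatively, scoped to a single forward/backward pass:
with torch.autograd.detect_anomaly():
    x = torch.randn(2, requires_grad=True)
    (x * x).sum().backward()
```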
I suspect that if I can understand the cause of the error in this reduced case, it will help me fix it in the actual codebase.
So what I would like to understand is: why does backward() work fine for the first two tensors in batched_feats but fail for the third? I would also greatly appreciate help in spotting where the already-freed intermediate results are being reused.
Thanks a lot!