如问题所述,我需要将张量附加到 Pytorch 计算图中的特定点。
我要做的是:在从所有小批量获取输出的同时,将它们累积在一个列表中,当一个时期结束时,计算平均值。然后,我需要根据平均值计算损失,因此反向传播必须考虑所有这些操作。
当训练数据不多时(无需分离和存储),我能够做到这一点。但是,当它变大时,这是不可能的。如果我不每次都分离输出张量,我的 GPU 内存就会用完,如果我分离,我会丢失计算图中输出张量的轨迹。看起来无论我有多少个 GPU,这都是不可能的,因为即使我分配了 4 个以上的 GPU,如果我在将它们保存到列表中之前不分离,PyTorch 只会使用前 4 个来存储输出张量。
非常感谢任何帮助。
谢谢。