我想在 a 的validation_epoch_end
方法中创建一个新的张量LightningModule
。从官方文档(第 48 页)中可以看出,我们应该避免直接.cuda()
或.to(device)
调用:
没有 .cuda() 或 .to() 调用。. . 闪电为你做这些。
并鼓励我们使用type_as
方法转移到正确的设备。
new_x = new_x.type_as(x.type())
但是,在一个步骤中,validation_epoch_end
我没有任何张量可以从(通过type_as
方法)以干净的方式复制设备。
我的问题是,如果我想用这种方法创建一个新的张量并将其转移到模型在哪里的设备上,我该怎么办?
我唯一能想到的就是在outputs
字典中找到一个张量,但感觉有点乱:
avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
output = self(self.__test_input.type_as(avg_loss))
有什么干净的方法可以实现这一目标吗?