尽管我尽了最大努力,但我还是无法运行 torch.jit.trace,遇到RuntimeError: Input, output and indices must be on the current device
我有一个(相当复杂的)模型,我已经把它放在 GPU 上,连同一组输入,也在 GPU 上。我可以验证所有输入张量和模型参数和缓冲区都在同一设备上:
(Pdb) {p.device for p in self.parameters()}
{device(type='cuda', index=0)}
(Pdb) {p.device for p in self.buffers()}
{device(type='cuda', index=0)}
(Pdb) in_ = (<several tensors here>)
(Pdb) {p.device for p in in_}
{device(type='cuda', index=0)}
(Pdb) torch.cuda.current_device()
0
我可以证明模型运行并且输出在正确的设备上:
(Pdb) self(*in_).device
device(type='cuda', index=0)
尽管如此,追踪还是失败了:
(Pdb) generator_script = torch.jit.trace(self, example_inputs=in_)
*** RuntimeError: Input, output and indices must be on the current device
- 我了解输入和输出,但是必须在同一设备上的这些“索引”是什么?
- 我没有考虑到的其他哪些因素可能导致跟踪失败?