我正在尝试对尺寸的 PyTorch 张量进行切片my_tensor
,s x b x c
以便沿第一个维度的切片根据indices
长度的张量而变化b
,效果如下:
my_tensor[0:indices, torch.arange(0, b, dtype=torch.long), :] = something
上面的代码不起作用并收到错误TypeError: tuple indices must be integers or slices, not tuple
。
我的目标是,例如,如果indices = torch.tensor([3, 5, 4])
那时:
my_tensor[0:3, 0, :] = something
my_tensor[0:5, 1, :] = something
my_tensor[0:4, 2, :] = something
我希望有一种张量的方式来做到这一点,所以我不必求助于 for 循环。此外,该方法需要与 TorchScript 兼容。非常感谢。