我想使用 BoolTensor 索引在 Pytorch 中对多维张量进行切片。我希望索引张量保留索引为真的部分,而将索引为假的部分切掉。
我的代码就像
import torch
a = torch.zeros((5, 50, 5, 50))
tr_indices = torch.zeros((50), dtype=torch.bool)
tr_indices[1:50:2] = 1
val_indices = ~tr_indices
print(a[:, tr_indices].shape)
print(a[:, tr_indices, :, val_indices].shape)
我希望a[:, tr_indices, :, val_indices]
是形状[5, 25, 5, 25]
,但它返回[25, 5, 5]
。结果是
torch.Size([5, 25, 5, 50])
torch.Size([25, 5, 5])
我很困惑。谁能解释为什么?