1

我想使用 BoolTensor 索引在 Pytorch 中对多维张量进行切片。我希望索引张量保留索引为真的部分,而将索引为假的部分切掉。

我的代码就像

import torch
a = torch.zeros((5, 50, 5, 50))

tr_indices = torch.zeros((50), dtype=torch.bool)
tr_indices[1:50:2] = 1
val_indices = ~tr_indices

print(a[:, tr_indices].shape)
print(a[:, tr_indices, :, val_indices].shape)

我希望a[:, tr_indices, :, val_indices]是形状[5, 25, 5, 25],但它返回[25, 5, 5]。结果是

torch.Size([5, 25, 5, 50])
torch.Size([25, 5, 5])

我很困惑。谁能解释为什么?

4

1 回答 1

1

PyTorch从 Numpy继承其高级索引行为。像这样切片两次应该可以达到您想要的输出:

a[:, tr_indices][..., val_indices]
于 2021-05-26T16:17:21.033 回答