0

我想从二维张量的每一行中提取存储在另一个一维张量中的列。

import torch
test_tensor = tensor([1,-2,3], [-2,7,4]).float()
select_tensor = tensor([1,2])

所以在这个特定的例子中,我想获得位置 1 的第一行的元素(so -2)和位置 2 的第二行的元素(so 4)。我试过了:

test_tensor[:, select_tensor]

但这会为每一行选择位置 1 和 2 的元素。我怀疑这可能是我错过的非常简单的事情。

4

2 回答 2

1

如果您正在寻找带有索引的解决方案,您也需要建立索引axis=0,您可以这样做torch.arange

>>> test_tensor = torch.tensor([[1,-2,3], [-2,7,4]])
>>> select_tensor = torch.tensor([1,2])

>>> test_tensor[torch.arange(len(select_tensor)), select_tensor]
tensor([-2,  4])
于 2021-01-19T20:26:46.117 回答
1

您可以使用torch.gather

import torch
test_tensor = torch.tensor([[1,-2,3], [-2,7,4]]).float()
select_tensor = torch.tensor([1,2], dtype=torch.int64).view(-1,1) # number of dimension should match with the test tensor.
final_tensor = torch.gather(test_tensor, 1, select_tensor)
final_tensor

输出

tensor([[-2.],
        [ 4.]])

或者,用于torch.view展平输出张量:final_tensor.view(-1)会给你tensor([-2., 4.])

于 2021-01-19T20:26:53.807 回答