我想从二维张量的每一行中提取存储在另一个一维张量中的列。
import torch
test_tensor = tensor([1,-2,3], [-2,7,4]).float()
select_tensor = tensor([1,2])
所以在这个特定的例子中,我想获得位置 1 的第一行的元素(so -2)和位置 2 的第二行的元素(so 4)。我试过了:
test_tensor[:, select_tensor]
但这会为每一行选择位置 1 和 2 的元素。我怀疑这可能是我错过的非常简单的事情。