1

当我尝试randperm使用 C++ PyTorch API 生成置换整数索引列表时,生成的张量具有元素类型CPUFloatType{10}而不是整数类型:

int N_SAMPLES = 10;               
torch::Tensor shuffled_indices = torch::randperm(N_SAMPLES); 
cout << shuffled_indices << endl;

返回

 9
 3
 8
 6
 2
 5
 4
 7
 1
 0
[ CPUFloatType{10} ]

它不能用于张量的索引,因为元素类型是浮点数而不是整数类型。当尝试使用my_tensor.index(shuffled_indices)我得到

terminate called after throwing an instance of 'c10::IndexError'
  what():  tensors used as indices must be long, byte or bool tensors

环境:

  • python-pytorch,Arch Linux 上的版本 1.6.0-2
  • g++ (GCC) 10.1.0

为什么会这样?

4

1 回答 1

2

这是因为您使用 torch 创建的任何张量的默认类型始终是float. 如果需要,您必须使用TensorOptions参数 struct 指定它:

int N_SAMPLES = 10;               
torch::Tensor shuffled_indices = torch::randperm(N_SAMPLES, torch::TensorOptions().dtype(at::kLong));
cout << shuffled_indices.dtype() << endl;
>>> long
于 2020-08-20T10:50:01.990 回答