当我尝试randperm
使用 C++ PyTorch API 生成置换整数索引列表时,生成的张量具有元素类型CPUFloatType{10}
而不是整数类型:
int N_SAMPLES = 10;
torch::Tensor shuffled_indices = torch::randperm(N_SAMPLES);
cout << shuffled_indices << endl;
返回
9
3
8
6
2
5
4
7
1
0
[ CPUFloatType{10} ]
它不能用于张量的索引,因为元素类型是浮点数而不是整数类型。当尝试使用my_tensor.index(shuffled_indices)
我得到
terminate called after throwing an instance of 'c10::IndexError'
what(): tensors used as indices must be long, byte or bool tensors
环境:
- python-pytorch,Arch Linux 上的版本 1.6.0-2
- g++ (GCC) 10.1.0
为什么会这样?