有两个数组,a 和索引。
a 的形状:(g,N),表示有 g 个组,所有样本都有 N 个。
索引的形状:(q,g),表示有 q 个类,每个类都包含不同的索引,供 g 组访问 a 的值。
例如,
a = [[1 3 7 8]
[2 4 5 6]] # shape:(2,4), 2 groups with 4 samples
indices = [[0 1]
[2 2]] # shape:(2,2), 2 class' with indices to access a for the two groups.
我尝试使用np.take(a, indices, axis=1)并获得
result = [[[1 3]
[7 7]]
[[2 4]
[5 5]]]
但这不是我想要的。我想要得到的结果是:
result = [[1,4]
[7,5]]
因为
indices[0] = [0,1] # class 0's indices for the two groups
a[0,0] = 1
a[1,1] = 4
indices[1] = [2,2] # class 1's indices for the two groups
a[0,2] = 7
a[1,2] = 5
有人可以帮忙吗?谢谢!