0

有两个数组,a 和索引。

a 的形状:(g,N),表示有 g 个组,所有样本都有 N 个。

索引的形状:(q,g),表示有 q 个类,每个类都包含不同的索引,供 g 组访问 a 的值。

例如,

a = [[1 3 7 8]
     [2 4 5 6]] # shape:(2,4), 2 groups with 4 samples 
indices = [[0 1]
           [2 2]] # shape:(2,2), 2 class' with indices to access a for the two groups.

我尝试使用np.take(a, indices, axis=1)并获得

result = [[[1 3]
           [7 7]]

          [[2 4]
          [5 5]]]

但这不是我想要的。我想要得到的结果是:

result = [[1,4]
          [7,5]]

因为

indices[0] = [0,1] # class 0's indices for the two groups

a[0,0] = 1 
a[1,1] = 4

indices[1] = [2,2] # class 1's indices for the two groups
a[0,2] = 7
a[1,2] = 5

有人可以帮忙吗?谢谢!

4

1 回答 1

2

使用take_along_axis

np.take_along_axis(a.T,indices,0)
# array([[1, 4],
#        [7, 5]])
于 2020-07-04T02:31:07.223 回答