3

我有一个二维 NumPy ndarray。

array([[  0.,  20.,  -2.],
   [  2.,   1.,   0.],
   [  4.,   3.,  20.]])

如何获得最大元素的所有索引?所以我想作为输出数组([0,1],[2,2])。

4

1 回答 1

3

np.argwhere最大相等掩码上使用-

np.argwhere(a == a.max())

样品运行 -

In [552]: a   # Input array
Out[552]: 
array([[  0.,  20.,  -2.],
       [  2.,   1.,   0.],
       [  4.,   3.,  20.]])

In [553]: a == a.max() # Max equality mask
Out[553]: 
array([[False,  True, False],
       [False, False, False],
       [False, False,  True]], dtype=bool)

In [554]: np.argwhere(a == a.max()) # array of row, col indices of max-mask
Out[554]: 
array([[0, 1],
       [2, 2]])

如果您正在使用浮点数,您可能希望在此处使用一些容差。因此,考虑到这一点,您可以使用np.isclose具有一些默认绝对和相对容差值的值。这将取代前面的a == a.max()部分,就像这样 -

In [555]: np.isclose(a, a.max())
Out[555]: 
array([[False,  True, False],
       [False, False, False],
       [False, False,  True]], dtype=bool)
于 2016-11-30T20:43:54.300 回答