1

我在下面给出了一个 numpy 数组(2 个元素列表的列表)a,并且我有一个[30.94, 0.]要查找的 2 个元素的列表。

当我执行以下操作时,我没有得到想要的结果。为什么?

import numpy as np
a = np.array([[  5.73,   0.  ],
              [ 57.73,  10.  ],
              [ 57.73,  20.  ],
              [ 30.94,   0.  ],
              [ 30.94,  10.  ],
              [ 30.94,  20.  ],
              [  4.14,   0.  ],
              [  4.14,  10.  ]])

np.where(a==np.array([30.94, 0.]))

但我明白了

(array([0, 3, 3, 4, 5, 6]), array([1, 0, 1, 0, 0, 1]))

这不是真的。

4

2 回答 2

4

正如 Divakar 暗示的那样,a == np.array([30.94, 0.])这不是您所期望的。数组是广播的,比较是按元素进行的。结果如下:

array([[False,  True],
       [False, False],
       [False, False],
       [ True,  True],
       [ True, False],
       [ True, False],
       [False,  True],
       [False, False]], dtype=bool)

但是,我们可以得到我们想要的np.all

>>> np.all(a==np.array([30.94, 0.]), axis=-1)
array([False, False, False,  True, False, False, False, False], dtype=bool)
>>> np.where(_)
(array([3]),)

所以你可以看到第 3 行匹配,正如预期的那样。请注意,使用==浮点数的常见注意事项将适用于此处。

于 2016-12-22T20:14:37.790 回答
1

还有另一种解决方案,但请注意,这将比Dietrich 的解决方案慢一点,特别是对于大型阵列。

In [1]: cond = np.array([30.94, 0.])
In [2]: arr = np.array([[  5.73,   0.  ],
                       [ 57.73,  10.  ],
                       [ 57.73,  20.  ],
                       [ 30.94,   0.  ],
                       [ 30.94,  10.  ],
                       [ 30.94,  20.  ],
                       [  4.14,   0.  ],
                       [  4.14,  10.  ]])

In [3]: [idx for idx, el in enumerate(arr) if np.array_equal(el, cond)]
Out[3]: [3]
于 2016-12-23T00:14:47.967 回答