196

我在 numpy 中有一个一维数组,我想找到索引的位置,其中某个值超过了 numpy 数组中的值。

例如

aa = range(-10,10)

查找超出aa值的位置。5

4

8 回答 8

252

这有点快(看起来更好)

np.argmax(aa>5)

由于argmax将在第一个停止True(“如果多次出现最大值,则返回与第一次出现对应的索引。”)并且不保存另一个列表。

In [2]: N = 10000

In [3]: aa = np.arange(-N,N)

In [4]: timeit np.argmax(aa>N/2)
100000 loops, best of 3: 52.3 us per loop

In [5]: timeit np.where(aa>N/2)[0][0]
10000 loops, best of 3: 141 us per loop

In [6]: timeit np.nonzero(aa>N/2)[0][0]
10000 loops, best of 3: 142 us per loop
于 2013-04-26T19:42:17.433 回答
111

给定数组的排序内容,还有一种更快的方法:searchsorted

import time
N = 10000
aa = np.arange(-N,N)
%timeit np.searchsorted(aa, N/2)+1
%timeit np.argmax(aa>N/2)
%timeit np.where(aa>N/2)[0][0]
%timeit np.nonzero(aa>N/2)[0][0]

# Output
100000 loops, best of 3: 5.97 µs per loop
10000 loops, best of 3: 46.3 µs per loop
10000 loops, best of 3: 154 µs per loop
10000 loops, best of 3: 154 µs per loop
于 2014-07-30T09:09:54.393 回答
30

我对此也很感兴趣,并将所有建议的答案与perfplot进行了比较。(免责声明:我是 perfplot 的作者。)

如果您知道您正在查看的数组已经排序,那么

numpy.searchsorted(a, alpha)

是给你的。这是 O(log(n)) 操作,即速度几乎不依赖于数组的大小。你不能比这更快。

如果您对阵列一无所知,那么您不会出错

numpy.argmax(a > alpha)

已经排序:

在此处输入图像描述

未分类:

在此处输入图像描述

重现情节的代码:

import numpy
import perfplot


alpha = 0.5
numpy.random.seed(0)


def argmax(data):
    return numpy.argmax(data > alpha)


def where(data):
    return numpy.where(data > alpha)[0][0]


def nonzero(data):
    return numpy.nonzero(data > alpha)[0][0]


def searchsorted(data):
    return numpy.searchsorted(data, alpha)


perfplot.save(
    "out.png",
    # setup=numpy.random.rand,
    setup=lambda n: numpy.sort(numpy.random.rand(n)),
    kernels=[argmax, where, nonzero, searchsorted],
    n_range=[2 ** k for k in range(2, 23)],
    xlabel="len(array)",
)
于 2018-01-20T20:47:54.667 回答
18
In [34]: a=np.arange(-10,10)

In [35]: a
Out[35]:
array([-10,  -9,  -8,  -7,  -6,  -5,  -4,  -3,  -2,  -1,   0,   1,   2,
         3,   4,   5,   6,   7,   8,   9])

In [36]: np.where(a>5)
Out[36]: (array([16, 17, 18, 19]),)

In [37]: np.where(a>5)[0][0]
Out[37]: 16
于 2013-04-26T19:42:53.313 回答
14

元素之间具有恒定步长的数组

如果是一个range或任何其他线性增加的数组,您可以简单地以编程方式计算索引,根本不需要实际迭代数组:

def first_index_calculate_range_like(val, arr):
    if len(arr) == 0:
        raise ValueError('no value greater than {}'.format(val))
    elif len(arr) == 1:
        if arr[0] > val:
            return 0
        else:
            raise ValueError('no value greater than {}'.format(val))

    first_value = arr[0]
    step = arr[1] - first_value
    # For linearly decreasing arrays or constant arrays we only need to check
    # the first element, because if that does not satisfy the condition
    # no other element will.
    if step <= 0:
        if first_value > val:
            return 0
        else:
            raise ValueError('no value greater than {}'.format(val))

    calculated_position = (val - first_value) / step

    if calculated_position < 0:
        return 0
    elif calculated_position > len(arr) - 1:
        raise ValueError('no value greater than {}'.format(val))

    return int(calculated_position) + 1

一个人可能会改进一点。我已经确保它适用于一些示例数组和值,但这并不意味着其中不会有错误,特别是考虑到它使用浮点数......

>>> import numpy as np
>>> first_index_calculate_range_like(5, np.arange(-10, 10))
16
>>> np.arange(-10, 10)[16]  # double check
6

>>> first_index_calculate_range_like(4.8, np.arange(-10, 10))
15

鉴于它可以在没有任何迭代的情况下计算位置,它将是恒定时间 ( O(1)) 并且可能会击败所有其他提到的方法。但是,它需要在数组中保持一个恒定的步长,否则会产生错误的结果。

使用 numba 的一般解决方案

更通用的方法是使用 numba 函数:

@nb.njit
def first_index_numba(val, arr):
    for idx in range(len(arr)):
        if arr[idx] > val:
            return idx
    return -1

这适用于任何数组,但它必须遍历数组,所以在平均情况下它将是O(n)

>>> first_index_numba(4.8, np.arange(-10, 10))
15
>>> first_index_numba(5, np.arange(-10, 10))
16

基准

尽管 Nico Schlömer 已经提供了一些基准,但我认为包含我的新解决方案并测试不同的“值”可能很有用。

测试设置:

import numpy as np
import math
import numba as nb

def first_index_using_argmax(val, arr):
    return np.argmax(arr > val)

def first_index_using_where(val, arr):
    return np.where(arr > val)[0][0]

def first_index_using_nonzero(val, arr):
    return np.nonzero(arr > val)[0][0]

def first_index_using_searchsorted(val, arr):
    return np.searchsorted(arr, val) + 1

def first_index_using_min(val, arr):
    return np.min(np.where(arr > val))

def first_index_calculate_range_like(val, arr):
    if len(arr) == 0:
        raise ValueError('empty array')
    elif len(arr) == 1:
        if arr[0] > val:
            return 0
        else:
            raise ValueError('no value greater than {}'.format(val))

    first_value = arr[0]
    step = arr[1] - first_value
    if step <= 0:
        if first_value > val:
            return 0
        else:
            raise ValueError('no value greater than {}'.format(val))

    calculated_position = (val - first_value) / step

    if calculated_position < 0:
        return 0
    elif calculated_position > len(arr) - 1:
        raise ValueError('no value greater than {}'.format(val))

    return int(calculated_position) + 1

@nb.njit
def first_index_numba(val, arr):
    for idx in range(len(arr)):
        if arr[idx] > val:
            return idx
    return -1

funcs = [
    first_index_using_argmax, 
    first_index_using_min, 
    first_index_using_nonzero,
    first_index_calculate_range_like, 
    first_index_numba, 
    first_index_using_searchsorted, 
    first_index_using_where
]

from simple_benchmark import benchmark, MultiArgument

并且这些图是使用以下方法生成的:

%matplotlib notebook
b.plot()

项目在开头

b = benchmark(
    funcs,
    {2**i: MultiArgument([0, np.arange(2**i)]) for i in range(2, 20)},
    argument_name="array size")

在此处输入图像描述

numba 函数执行得最好,其次是 calculate-function 和 searchsorted 函数。其他解决方案的表现要差得多。

项目在最后

b = benchmark(
    funcs,
    {2**i: MultiArgument([2**i-2, np.arange(2**i)]) for i in range(2, 20)},
    argument_name="array size")

在此处输入图像描述

对于小数组,numba 函数的执行速度惊人地快,但是对于较大的数组,它的计算函数和 searchsorted 函数的性能要好得多。

项目在 sqrt(len)

b = benchmark(
    funcs,
    {2**i: MultiArgument([np.sqrt(2**i), np.arange(2**i)]) for i in range(2, 20)},
    argument_name="array size")

在此处输入图像描述

这更有趣。同样 numba 和 calculate 函数表现出色,但这实际上触发了最坏的 searchsorted 情况,在这种情况下确实不能正常工作。

没有值满足条件时的函数比较

另一个有趣的点是,如果没有应返回其索引的值,这些函数的行为方式:

arr = np.ones(100)
value = 2

for func in funcs:
    print(func.__name__)
    try:
        print('-->', func(value, arr))
    except Exception as e:
        print('-->', e)

有了这个结果:

first_index_using_argmax
--> 0
first_index_using_min
--> zero-size array to reduction operation minimum which has no identity
first_index_using_nonzero
--> index 0 is out of bounds for axis 0 with size 0
first_index_calculate_range_like
--> no value greater than 2
first_index_numba
--> -1
first_index_using_searchsorted
--> 101
first_index_using_where
--> index 0 is out of bounds for axis 0 with size 0

Searchsorted、argmax 和 numba 只是返回错误值。但是searchsortednumba返回一个不是数组有效索引的索引。

函数wheremin和抛出异常nonzerocalculate然而,实际上只有例外才能calculate说明任何有用的东西。

这意味着实际上必须将这些调用包装在一个适当的包装函数中,该函数捕获异常或无效返回值并进行适当处理,至少在您不确定该值是否可以在数组中的情况下。


注意:计算和searchsorted选项仅在特殊条件下有效。“计算”函数需要一个恒定的步骤,而 searchsorted 需要对数组进行排序。因此,这些在适当的情况下可能很有用,但不是解决此问题的通用解决方案。如果您正在处理已排序的Python 列表,您可能希望查看bisect模块,而不是使用 Numpys searchsorted。

于 2018-04-19T17:36:06.420 回答
4

我想提议

np.min(np.append(np.where(aa>5)[0],np.inf))

这将返回满足条件的最小索引,而如果条件从未满足则返回无穷大(并where返回一个空数组)。

于 2018-05-08T18:47:34.673 回答
1

我会去

i = np.min(np.where(V >= x))

其中V是向量(一维数组),x是值,i是结果索引。

于 2018-02-17T18:13:17.503 回答
0

您应该使用np.where而不是np.argmax. 即使没有找到值,后者也会返回位置 0,这不是您期望的索引。

>>> aa = np.array(range(-10,10))
>>> print(aa)
array([-10,  -9,  -8,  -7,  -6,  -5,  -4,  -3,  -2,  -1,   0,   1,   2,
         3,   4,   5,   6,   7,   8,   9])

如果满足条件,则返回索引数组。

>>> idx = np.where(aa > 5)[0]
>>> print(idx)
array([16, 17, 18, 19], dtype=int64)

否则,如果不满足,则返回一个空数组。

>>> not_found = len(np.where(aa > 20)[0])
>>> print(not_found)
array([], dtype=int64)

这种情况的反对点argmax是:越简单最好,如果解决方案不模棱两可。因此,要检查是否有某些东西落入这种状态,只需执行if len(np.where(aa > value_to_search)[0]) > 0.

于 2021-08-12T19:50:36.117 回答