我需要过滤一个数组以删除低于某个阈值的元素。我当前的代码是这样的:
threshold = 5
a = numpy.array(range(10)) # testing data
b = numpy.array(filter(lambda x: x >= threshold, a))
问题是这会创建一个临时列表,使用带有 lambda 函数的过滤器(慢)。
由于这是一个非常简单的操作,也许有一个 numpy 函数可以有效地完成它,但我一直无法找到它。
我认为实现这一点的另一种方法可能是对数组进行排序,找到阈值的索引并从该索引开始返回一个切片,但即使这对于小输入来说会更快(而且它不会被注意到),随着输入大小的增长,它的效率肯定会逐渐降低。
有任何想法吗?谢谢!
更新:我也进行了一些测量,当输入为 100.000.000 个条目时,排序+切片仍然比纯 python 过滤器快两倍。
In [321]: r = numpy.random.uniform(0, 1, 100000000)
In [322]: %timeit test1(r) # filter
1 loops, best of 3: 21.3 s per loop
In [323]: %timeit test2(r) # sort and slice
1 loops, best of 3: 11.1 s per loop
In [324]: %timeit test3(r) # boolean indexing
1 loops, best of 3: 1.26 s per loop