您可以让 numpy 处理迭代,即对其进行矢量化:
def local_maxima(xval, yval):
xval = np.asarray(xval)
yval = np.asarray(yval)
sort_idx = np.argsort(xval)
yval = yval[sort_idx]
gradient = np.diff(yval)
maxima = np.diff((gradient > 0).view(np.int8))
return np.concatenate((([0],) if gradient[0] < 0 else ()) +
(np.where(maxima == -1)[0] + 1,) +
(([len(yval)-1],) if gradient[-1] > 0 else ()))
编辑所以代码首先计算从每个点到 nex( gradient
) 的变化。下一步有点棘手......如果你这样做np.diff((gradient > 0)
,得到的布尔数组就是True
从增长(> 0
)变为不增长(<= 0
)的地方。通过将其设置为与布尔数组大小相同的有符号整数,您可以区分从增长到不增长 ( -1
) 到相反 ( +1
) 的转换。通过采用.view(np.int8)
与布尔数组相同的 dtype 大小的有符号整数类型,我们避免了复制数据,如果我们做更少的hacky会发生这种情况.astype(int)
. 剩下的就是处理第一个和最后一个点,并将所有点连接到一个数组中。我今天发现的一件事是,如果您在发送到的元组中包含一个空列表np.concatenate
,它会作为 dtype 的空数组出现np.float
,并且最终成为结果的 dtype,因此空元组的连接更复杂在上面的代码中。
有用:
In [2]: local_maxima(xval, yval)
Out[2]: array([ 1, 6, 10], dtype=int64)
并且相当快:
In [3]: xval = np.random.rand(10000)
In [4]: yval = np.random.rand(10000)
In [5]: local_maxima(xval, yval)
Out[5]: array([ 0, 2, 4, ..., 9991, 9995, 9998], dtype=int64)
In [6]: %timeit local_maxima(xval, yval)
1000 loops, best of 3: 1.16 ms per loop
此外,大部分时间是将数据从列表转换为数组并对其进行排序。如果您的数据已经排序并保存在数组中,您可能可以将上述性能提高 5 倍。