0

我在 2D 空间中有 N 个点(N 行,2 列),我想为(x_n,y_n)点集中的每个点找到最近的 k 个点,然后对它们进行排序。这是用于此目的的代码。我想加快以下代码:

def nearst_sort(x,y,k):
    N = len(x)    
    A = np.zeros((N,k))
    R = np.zeros((N,N))
    R = (x - x[np.newaxis].transpose())**2 + (y -y[np.newaxis].transpose())**2
    ix = np.argsort(R, kind='stable')
    ix = ix.transpose()
    A=ix[0:k,:].transpose()
    return A

我的样本数据如下:

x   y
0   0
0   0.5
0   1
0.5 0
0.5 0.5
0.5 1
1   0
1   0.5
1   1

我也尝试过 scipy.spatial.KDTree 中的函数,但没有得到好的结果。任何帮助,将不胜感激。

4

1 回答 1

0

对我来说,KDTree 要快得多:

>>> import numpy as np
>>> from scipy.spatial import cKDTree as KDTree
>>> from timeit import timeit
>>> 
>>> z = np.random.randn(1000, 2)
>>> k = 5
>>> 
>>> KDTree(z).query(z, k)
(array([[0.        , 0.21130505, 0.22903208, 0.32009477, 0.38000444],
       [0.        , 0.03969915, 0.06698214, 0.08423566, 0.10740011],
       [0.        , 0.04964421, 0.08194808, 0.11576068, 0.12022531],
       ...,
       [0.        , 0.00721785, 0.03346301, 0.03617199, 0.04193239],
       [0.        , 0.05147871, 0.05619545, 0.08028866, 0.08744349],
       [0.        , 0.03733766, 0.06359033, 0.06861222, 0.0698981 ]]), array([[  0, 391, 134, 462, 575],
       [  1,  87, 879, 846, 122],
       [  2, 793, 314, 564, 483],
       ...,
       [997, 390, 432, 165, 952],
       [998, 194, 457, 775, 629],
       [999, 158, 522, 862, 791]]))
>>> nearst_sort(*z.T, k)
array([[  0., 391., 134., 462., 575.],
       [  1.,  87., 879., 846., 122.],
       [  2., 793., 314., 564., 483.],
       ...,
       [997., 390., 432., 165., 952.],
       [998., 194., 457., 775., 629.],
       [999., 158., 522., 862., 791.]])
>>> timeit(lambda: KDTree(z).query(z, k), number=100)
0.12790076900273561
>>> timeit(lambda: nearst_sort(*z.T, k), number=100)
6.5285790269990684

这是 50 倍。不过,可能取决于示例。

于 2019-03-18T19:16:44.057 回答