您的代码不是有效的 Python,我认为您混淆了您的形状......除此之外,如果您有足够的内存,您可以通过使用广播来矢量化距离计算来摆脱循环。如果你有这些数据:
n_set_points = 100
n_test_points = 10000
n_dims = 60
set_points = np.random.rand(n_set_points, n_dims)
test_points = np.random.rand(n_test_points, n_dims)
那么这是最直接的计算:
# deltas.shape = (n_set_points, n_test_point, n_dims)
deltas = (set_points[:, np.newaxis, :] -
test_points[np.newaxis, ...])
# dist[j, k] holds the squared distance between the
# j-th set_point and the k-th test point
dist = np.sum(deltas*deltas, axis=-1)
# nearest[j] is the index of the set_point closest to
# each test_point, has shape (n_test_points,)
nearest = np.argmin(dist, axis=0)
交易破坏者是您是否可以存储deltas
在内存中:它可以是一个巨大的数组。如果这样做,则可以通过使用更神秘但更有效的距离计算来获得一些性能:
dist = np.einsum('jkd,jkd->jk', deltas, deltas)
如果deltas
太大,请将您的 test_points 分成可管理的块,然后循环遍历这些块,例如:
def nearest_neighbor(set_pts, test_pts, chunk_size):
n_test_points = len(test_pts)
ret = np.empty((n_test_points), dtype=np.intp)
for chunk_start in xrange(0, n_test_points ,chunk_size):
deltas = (set_pts[:, np.newaxis, :] -
test_pts[np.newaxis,
chunk_start:chunk_start + chunk_size, :])
dist = np.einsum('jkd,jkd->jk', deltas,deltas)
ret[chunk_start:chunk_start + chunk_size] = np.argmin(dist, axis=0)
return ret
%timeit nearest_neighbor(set_points, test_points, 1)
1 loops, best of 3: 283 ms per loop
%timeit nearest_neighbor(set_points, test_points, 10)
1 loops, best of 3: 175 ms per loop
%timeit nearest_neighbor(set_points, test_points, 100)
1 loops, best of 3: 384 ms per loop
%timeit nearest_neighbor(set_points, test_points, 1000)
1 loops, best of 3: 365 ms per loop
%timeit nearest_neighbor(set_points, test_points, 10000)
1 loops, best of 3: 374 ms per loop
因此,通过进行部分矢量化可以获得一些性能。