这是一个解决方案。我不确定它是否可以矢量化。如果你想让它能够抵抗浮点数比较误差,你应该修改 is_less 和 is_greater。整个算法只是一个二分搜索。
import numpy as np
# Lexicographically compare two points, a and b.
def is_less(a, b):
    """Return True when point ``a`` sorts strictly before point ``b``.

    The two points are compared lexicographically: coordinates are examined
    from index 0 upwards and the first pair that differs decides the result.
    Only the first ``len(a)`` coordinates are looked at, so ``b`` must have
    at least that many entries.

    Returns False when no coordinate differs (the points are equal).
    """
    # Index-based loop (rather than zip) so that a too-short ``b`` still
    # raises IndexError instead of being silently truncated.
    for pos in range(len(a)):
        if a[pos] < b[pos]:
            return True
        if a[pos] > b[pos]:
            return False
    return False
def is_greater(a, b):
    """Return True when point ``a`` sorts strictly after point ``b``.

    The two points are compared lexicographically: coordinates are examined
    from index 0 upwards and the first pair that differs decides the result.
    Only the first ``len(a)`` coordinates are looked at, so ``b`` must have
    at least that many entries.

    Returns False when no coordinate differs (the points are equal).
    """
    # Index-based loop (rather than zip) so that a too-short ``b`` still
    # raises IndexError instead of being silently truncated.
    for pos in range(len(a)):
        if a[pos] > b[pos]:
            return True
        if a[pos] < b[pos]:
            return False
    return False
def binary_search(a, x, lo=0, hi=None):
    """Find point ``x`` in the lexicographically sorted sequence ``a``.

    The half-open index range ``[lo, hi)`` is searched; ``hi`` defaults to
    ``len(a)``. Ordering is decided by ``is_less`` / ``is_greater``, and a
    row that is neither less nor greater than ``x`` counts as a match.

    Returns the index of a matching row, or -1 when there is none.
    """
    left = lo
    right = len(a) if hi is None else hi
    while left < right:
        middle = (left + right) // 2
        probe = a[middle]
        if is_less(probe, x):
            # Everything up to and including ``middle`` is too small.
            left = middle + 1
        elif is_greater(probe, x):
            # ``middle`` and everything after it is too large.
            right = middle
        else:
            return middle
    return -1
def lex_sort(v):
    """Return the rows of 2-D array ``v`` sorted by columns 1, 2, ... in turn.

    Column 0 is not used as a sort key. For a three-column array this is
    equivalent to ``v[np.lexsort((v[:, 2], v[:, 1]))]``.
    """
    # np.lexsort treats the LAST key as the primary one, so the key columns
    # are handed over from the last column down to column 1.
    keys = tuple(v[:, col] for col in range(v.shape[1] - 1, 0, -1))
    row_order = np.lexsort(keys)
    return v[row_order]
def sort_and_index(arr):
    """Sort the rows of ``arr`` lexicographically and keep their origin.

    Parameters:
        arr: 2-D array-like, one point per row.

    Returns:
        A pair ``(arr_ind, arr_cut)``:
        ``arr_cut`` holds the rows of ``arr`` in lexicographic order (column
        0 is the primary key, then column 1, and so on) and is the array to
        binary-search in; ``arr_ind`` is an ``(n, 1)`` integer array where
        ``arr_ind[k][0]`` is the index that row ``arr_cut[k]`` had in ``arr``.

    Raises:
        ValueError: if ``arr`` is not 2-D.
    """
    arr = np.asarray(arr)
    if arr.ndim != 2:
        raise ValueError("sort_and_index expects a 2-D array of points")
    # The permutation is computed directly instead of gluing an index column
    # onto the data: stacking would force indices and data into one dtype, so
    # float data would yield float "indices" that cannot be used to index.
    # np.lexsort uses the LAST key as the primary one; with a 2-D argument
    # the keys are its rows, hence the transposed, reversed view.
    order = np.lexsort(arr.T[::-1])
    arr_ind = order.reshape((len(arr), 1))  # shuffled original indices
    arr_cut = arr[order]  # sorted rows
    return arr_ind, arr_cut
# Demo / timing run: for every row of lat2, look up the index of the
# identical row in lat1 by binary search over a sorted copy of lat1.
#lat1 = np.array(([1,2,3], [3,4,5], [5,6,7], [7,8,9])) # ~ 200000 rows
lat1 = np.arange(1, 800001, 1).reshape((200000, 4))
#lat2 = np.array(([3,4,5], [5,6,7], [7,8,9], [1,2,3])) # same number of rows as time
lat2 = np.arange(101, 800101, 1).reshape((200000, 4))
lat1_ind, lat1_cut = sort_and_index(lat1)
time_arr = np.zeros(200000)  # not used by the lookup loop below
import time
start = time.time()
for ii, elem in enumerate(lat2):
    pos = binary_search(lat1_cut, elem)
    if pos == -1:
        # Not found: this row of lat2 does not occur in lat1.
        continue
    # Translate the position in the sorted copy back to the row index in lat1.
    pos = lat1_ind[pos][0]
    # Here ``ii`` (row in lat2) and ``pos`` (row in lat1) are a matching pair:
    # print("element in lat2 with index", ii, "has position", pos, "in lat1")
# Call form of print: works on Python 2 and Python 3 alike, whereas the
# statement form ``print x`` is a SyntaxError on Python 3.
print(time.time() - start)
在被注释掉的 print 语句处,即可同时得到 lat1 和 lat2 中相互对应的行索引。处理 200000 行大约需要 7 秒。