好吧,为那些 numpy speed-freaks 买了一个奇怪的。
我有这样的数据:
- Nx2 整数值数组
- 0 到 N-1 之间的每个整数出现两次
- 数据中会有一个或多个“循环”。
“循环”将是排序的行的子集,使得每一行与其上方的行共享一个元素,另一个元素与其下方的行共享。目标是找到产生闭环的数据的索引数组。
示例数据(单循环):
In: data = np.array([[0, 7],
[1, 8],
[2, 9],
[3, 0],
[4, 1],
[5, 2],
[6, 3],
[4, 7],
[8, 5],
[9, 6]])
示例解决方案:
In: ordered_indices = np.array([0, 7, 4, 1, 8, 5, 2, 9, 6, 3])
In: data[ordered_indices]
Out: array([[0, 7],
[4, 7],
[4, 1],
[1, 8],
[8, 5],
[5, 2],
[2, 9],
[9, 6],
[6, 3],
[3, 0]])
不保证行中元素的顺序;即,7 可以是它出现的两行中的第一个元素,或者是其中的第一个元素和另一个中的第二个元素。
数据量级为N=1000;带有循环的解决方案太慢了。
为方便起见,可以使用以下脚本生成典型数据。在这里,有序数据的索引遵循周期性模式,但在实际数据中并非如此。
生成样本数据:
import numpy as np
import sys
# parameters
N = 1000
M = 600
# initialize array
data = np.empty((N,2), dtype=np.int)
# populate first column
data[:,0] = np.arange(N)
# populate second column by shifting first column; create two loops within the data
inds1 = np.arange(0,M)[np.arange(-7,M-7)]
inds2 = np.arange(M,N)[np.arange(-9,N-M-9)]
data[:M,1] = data[inds1,0]
data[M:,1] = data[inds2,0]
# shuffle order of two entries in rows
map(np.random.shuffle, data)
我已经编写了一种可以得到正确结果的方法,但是速度很慢(在我老化的笔记本电脑上大约需要 0.5 秒):
基线解决方案:
def groupRows(rows):
# create a list of indices
ungrouped_rows = range(len(rows))
# initialize list of lists of indices
grouped_rows = []
# loop until there are no ungrouped rows
while 0 < len(ungrouped_rows):
# remove a row from the beginning of the list
row_index = ungrouped_rows.pop(0)
# start a new list of rows
grouped_rows.append([row_index])
# get the element at the start of the loop
stop_at = data[grouped_rows[-1][0],0]
# search target
look_for = data[grouped_rows[-1][0],1]
# continue until loop is closed
closed = False
while not closed:
# for every row left in the ungrouped list
for i, row_index in enumerate(ungrouped_rows):
# get two elements in the row being checked
b1,b2 = data[row_index]
# if row contains the current search target
if look_for in (b1,b2):
# add it to the current group
grouped_rows[-1].append(ungrouped_rows.pop(i))
# update search target
if look_for == b1:
look_for = b2
elif look_for == b2:
look_for = b1
# exit the loop through the ungrouped rows
break
# check if the loop is closed
if look_for == stop_at:
closed = True
return map(np.array, grouped_rows)
所以我的方法有效,但使用列表和两个嵌套循环;使用 numpy 更有效的方法必须有一种更巧妙的方法来做到这一点。有任何想法吗?