DTAIDistance 包可用于查找输入查询的 k 个最佳匹配,但它不能用于多维输入查询。此外,我希望在一次运行中为多个输入查询分别找到 k 个最佳匹配。
我修改了 DTAIDistance 的这个函数,使它可以用于多维、多查询的子序列搜索。我使用带 parallel 选项的 njit 来加快处理速度,即 p_calc 函数利用 numba 的并行功能同时处理各个输入查询。但我发现,与简单地逐个循环处理输入查询(即 calc 函数)相比,并行计算似乎并没有加快计算速度。
import time
from tqdm import tqdm
from numba import njit, prange
import numpy as np
inf = np.inf
argmin=np.argmin
@njit(fastmath=True, nogil=True, error_model="numpy", cache=True, parallel=True)
def p_calc(d, dtw, s1, s2, r, c, psi_1b, psi_1e, psi_2b, psi_2e, window, max_step, max_dist, penalty, psi_neg):
    """Fill one DTW warping-paths matrix per query series, queries in parallel.

    Each query ``nn`` only reads ``s1[:, nn, :]`` and only writes ``d[nn]`` and
    ``dtw[nn]``, so the outer loop is a ``prange`` and the function is compiled
    with ``parallel=True``.

    Args:
        d: 1-D output array with one slot per query; ``d[nn]`` receives the
            final distance of query ``nn``.
        dtw: 3-D work/output array indexed ``[query, row, column]`` with at
            least ``r + 1`` rows and ``c + 1`` columns. It is modified in place
            and is not initialised here apart from the psi borders, so the
            caller must pre-fill it (presumably with +inf).
        s1: query data, indexed ``s1[row, query, dim]``.
        s2: series that is searched, indexed ``s2[column, dim]``.
        r: number of rows (samples of each query) to process.
        c: number of columns (samples of ``s2``) to process.
        psi_1b: cells ``dtw[:, 0..psi_1b, 0]`` are set to 0 (free start rows).
        psi_1e: number of trailing rows of the last column that may end the
            path; 0 disables this relaxation.
        psi_2b: cells ``dtw[:, 0, 0..psi_2b]`` are set to 0 (free start columns).
        psi_2e: number of trailing columns of the last row that may end the
            path; 0 disables this relaxation.
        window: width of the band of columns evaluated around the diagonal.
        max_step: cells whose local squared cost exceeds this are skipped;
            ``None`` disables the check.
        max_dist: bound on the accumulated squared cost used for pruning and
            for rejecting the final distance.
        penalty: cost added to the two non-diagonal moves.
        psi_neg: when true, the trailing cells skipped by the psi relaxation
            are overwritten with -1.

    Returns:
        The tuple ``(d, dtw)``; both are the arrays passed in. ``dtw`` holds
        square-rooted accumulated costs on return.
    """
    n_series = s1.shape[1]
    ndim = s1.shape[2]
    # Psi relaxation at the start: zero cost to begin anywhere in these cells.
    for i in range(psi_2b + 1):
        dtw[:, 0, i] = 0
    for i in range(psi_1b + 1):
        dtw[:, i, 0] = 0
    for nn in prange(n_series):
        i0 = 1
        i1 = 0
        # sc: first column still worth evaluating in the next row.
        # ec: column beyond which a row may stop once it exceeds max_dist.
        sc = 0
        ec = 0
        smaller_found = False
        ec_next = 0
        for i in range(r):
            i0 = i
            i1 = i + 1
            j_start = max(0, i - max(0, r - c) - window + 1)
            j_end = min(c, i + max(0, c - r) + window)
            if sc > j_start:
                j_start = sc
            smaller_found = False
            ec_next = i
            for j in range(j_start, j_end):
                # Accumulate the squared distance in a scalar: an array
                # expression here would allocate a temporary for every cell,
                # and writing to the shared ``d`` would add needless traffic.
                cost = 0.0
                for nd in range(ndim):
                    cost += (s1[i, nn, nd] - s2[j, nd]) ** 2
                if max_step is not None and cost > max_step:
                    continue
                dtw[nn, i1, j + 1] = cost + min(dtw[nn, i0, j],
                                                dtw[nn, i0, j + 1] + penalty,
                                                dtw[nn, i1, j] + penalty)
                if dtw[nn, i1, j + 1] > max_dist:
                    if not smaller_found:
                        # Nothing usable seen yet in this row: the next row
                        # can start after this column.
                        sc = j + 1
                    if j >= ec:
                        break
                else:
                    smaller_found = True
                    ec_next = j + 1
            ec = ec_next
        # Accumulated costs are squared up to here; convert to distances.
        dtw[nn] = np.sqrt(dtw[nn])
        if psi_1e == 0 and psi_2e == 0:
            d[nn] = dtw[nn, i1, min(c, c + window - 1)]
        else:
            ir = i1
            ic = min(c, c + window - 1)
            if psi_1e != 0:
                # Best end point among the last rows of the final column
                # (the slice is reversed, so ``mir`` counts back from ``ir``).
                vr = dtw[nn, ir:max(0, ir - psi_1e - 1):-1, ic]
                mir = np.argmin(vr)
                vr_mir = vr[mir]
            else:
                mir = ir
                vr_mir = inf
            if psi_2e != 0:
                # Best end point among the last columns of the final row.
                vc = dtw[nn, ir, ic:max(0, ic - psi_2e - 1):-1]
                mic = np.argmin(vc)
                vc_mic = vc[mic]
            else:
                mic = ic
                vc_mic = inf
            if vr_mir < vc_mic:
                if psi_neg:
                    dtw[nn, ir:ir - mir:-1, ic] = -1
                d[nn] = vr_mir
            else:
                if psi_neg:
                    dtw[nn, ir, ic:ic - mic:-1] = -1
                d[nn] = vc_mic
        # max_dist bounds the squared cost, hence the comparison with d ** 2.
        if max_dist and d[nn] ** 2 > max_dist:
            d[nn] = inf
    return d, dtw
@njit(fastmath=True, nogil=True)  # @njit is nopython mode
def calc(s1, s2, r, c, psi_1b, psi_1e, psi_2b, psi_2e, window, max_step, max_dist, penalty, psi_neg):
    """Compute the DTW warping-paths matrix of one multi-dimensional query.

    Args:
        s1: query, indexed ``s1[row, dim]`` (2-D).
        s2: series that is searched, indexed ``s2[column, dim]``.
        r: number of rows (samples of ``s1``) to process.
        c: number of columns (samples of ``s2``) to process.
        psi_1b: cells ``dtw[0..psi_1b, 0]`` are set to 0 (free start rows).
        psi_1e: number of trailing rows of the last column that may end the
            path; 0 disables this relaxation.
        psi_2b: cells ``dtw[0, 0..psi_2b]`` are set to 0 (free start columns).
        psi_2e: number of trailing columns of the last row that may end the
            path; 0 disables this relaxation.
        window: width of the band of columns evaluated around the diagonal.
        max_step: cells whose local squared cost exceeds this are skipped;
            ``None`` disables the check.
        max_dist: bound on the accumulated squared cost used for pruning and
            for rejecting the final distance.
        penalty: cost added to the two non-diagonal moves.
        psi_neg: when true, the trailing cells skipped by the psi relaxation
            are overwritten with -1.

    Returns:
        ``(d, dtw)``: the final distance (``inf`` when its square exceeds
        ``max_dist``) and the ``(r + 1, c + 1)`` matrix of square-rooted
        accumulated costs.
    """
    ndim = s1.shape[1]
    dtw = np.full((r + 1, c + 1), np.inf)
    # Psi relaxation at the start: zero cost to begin anywhere in these cells.
    for i in range(psi_2b + 1):
        dtw[0, i] = 0
    for i in range(psi_1b + 1):
        dtw[i, 0] = 0
    i0 = 1
    i1 = 0
    # sc: first column still worth evaluating in the next row.
    # ec: column beyond which a row may stop once it exceeds max_dist.
    sc = 0
    ec = 0
    smaller_found = False
    ec_next = 0
    for i in range(r):
        i0 = i
        i1 = i + 1
        j_start = max(0, i - max(0, r - c) - window + 1)
        j_end = min(c, i + max(0, c - r) + window)
        if sc > j_start:
            j_start = sc
        smaller_found = False
        ec_next = i
        for j in range(j_start, j_end):
            # Scalar accumulation instead of np.sum over an array expression,
            # which would allocate a temporary array for every cell.
            cost = 0.0
            for k in range(ndim):
                cost += (s1[i, k] - s2[j, k]) ** 2
            if max_step is not None and cost > max_step:
                continue
            dtw[i1, j + 1] = cost + min(dtw[i0, j],
                                        dtw[i0, j + 1] + penalty,
                                        dtw[i1, j] + penalty)
            if dtw[i1, j + 1] > max_dist:
                if not smaller_found:
                    # Nothing usable seen yet in this row: the next row can
                    # start after this column.
                    sc = j + 1
                if j >= ec:
                    break
            else:
                smaller_found = True
                ec_next = j + 1
        ec = ec_next
    # Accumulated costs are squared up to here; convert to distances.
    dtw = np.sqrt(dtw)
    if psi_1e == 0 and psi_2e == 0:
        d = dtw[i1, min(c, c + window - 1)]
    else:
        ir = i1
        ic = min(c, c + window - 1)
        if psi_1e != 0:
            # Best end point among the last rows of the final column (the
            # slice is reversed, so ``mir`` counts back from ``ir``).
            vr = dtw[ir:max(0, ir - psi_1e - 1):-1, ic]
            mir = argmin(vr)
            vr_mir = vr[mir]
        else:
            mir = ir
            vr_mir = inf
        if psi_2e != 0:
            # Best end point among the last columns of the final row.
            vc = dtw[ir, ic:max(0, ic - psi_2e - 1):-1]
            mic = argmin(vc)
            vc_mic = vc[mic]
        else:
            mic = ic
            vc_mic = inf
        if vr_mir < vc_mic:
            if psi_neg:
                dtw[ir:ir - mir:-1, ic] = -1
            d = vr_mir
        else:
            if psi_neg:
                dtw[ir, ic:ic - mic:-1] = -1
            d = vc_mic
    # max_dist bounds the squared cost, hence the comparison with d * d.
    if max_dist and d * d > max_dist:
        d = inf
    return d, dtw
# Benchmark: 30 two-dimensional queries of length 16 against one long series.
mydtype = np.float32
series1 = np.random.random((16, 30, 2)).astype(mydtype)
series2 = np.random.random((100000, 2)).astype(mydtype)
n_series = series1.shape[1]
r = series1.shape[0]
c = series2.shape[0]
dtw = np.full((n_series, r + 1, c + 1), np.inf, dtype=mydtype)
d = np.full((n_series), np.inf, dtype=mydtype)

# Warm-up: the first call of an @njit function compiles it, which would
# otherwise be counted in the timings below. The tiny inputs are slices of the
# real data so that dtypes and memory layouts (and therefore the compiled
# specialisations) match the timed calls.
_warm_s1 = series1[:2]
_warm_s2 = series2[:4]
_warm_r = _warm_s1.shape[0]
_warm_c = _warm_s2.shape[0]
p_calc(np.full(n_series, np.inf, dtype=mydtype),
       np.full((n_series, _warm_r + 1, _warm_c + 1), np.inf, dtype=mydtype),
       _warm_s1, _warm_s2, _warm_r, _warm_c, 0, 0,
       _warm_c, _warm_c, _warm_c, np.inf, np.inf, 0.01, False)
calc(_warm_s1[:, 0, :], _warm_s2, _warm_r, _warm_c, 0, 0,
     _warm_c, _warm_c, _warm_c, np.inf, np.inf, 0.01, False)

# All queries in a single call.
time1 = time.perf_counter()
d, dtw1 = p_calc(d, dtw, series1, series2, series1.shape[0], series2.shape[0], 0, 0,
                 series2.shape[0], series2.shape[0], series2.shape[0], np.inf, np.inf, 0.01, False)
print(time.perf_counter() - time1)

# One call per query; the author observed this variant to be the faster one.
time1 = time.perf_counter()
for ii in tqdm(range(series1.shape[1])):
    d, dtw1 = calc(series1[:, ii, :], series2, series1.shape[0], series2.shape[0], 0, 0,
                   series2.shape[0], series2.shape[0], series2.shape[0], np.inf, np.inf, 0.01, False)
print(time.perf_counter() - time1)
如何加速 calc 函数或 p_calc 函数,以便计算多维多查询的动态时间规整路径?
感谢您的回答。随后我对代码做了简化:去掉了 np.sum,改用显式循环,从而又获得了一次加速。对进一步加速还有什么建议吗?
import time
from numba import njit, prange
import numpy as np
inf = np.inf
argmin=np.argmin
@njit(fastmath=True, nogil=True, error_model="numpy", cache=False, parallel=True)
def p_calc(d, dtw, s1, s2, r, c, psi_1b, psi_1e, psi_2b, psi_2e, window, max_step, max_dist, penalty, psi_neg):
    """Stripped-down kernel: squared distances only, queries in parallel.

    For every query ``q`` (second axis of ``s1``) the squared Euclidean
    distance between ``s1[row, q]`` and ``s2[col]`` is computed for all
    ``row < r`` and ``col < c`` and stored in ``d[q]``, so on return ``d[q]``
    holds the value of the last pair visited. ``dtw`` and the remaining
    parameters are accepted only to keep the signature; they are not used.

    Returns:
        ``(d, dtw)``, the same arrays that were passed in.
    """
    num_queries = s1.shape[1]
    num_dims = s1.shape[2]
    for q in prange(num_queries):
        for row in range(r):
            for col in range(c):
                acc = 0
                for axis in range(num_dims):
                    acc += (s1[row, q, axis] - s2[col, axis]) ** 2
                d[q] = acc
    return d, dtw
@njit(fastmath=True, nogil=True)  # @njit is nopython mode
def calc(dtw,s1, s2, r, c, psi_1b, psi_1e, psi_2b, psi_2e, window, max_step, max_dist, penalty, psi_neg):
    """Stripped-down single-query kernel: squared distances only.

    Computes the squared Euclidean distance between ``s1[row]`` and
    ``s2[col]`` for all ``row < r`` and ``col < c``; only the value of the
    last pair survives. ``dtw`` is returned untouched and the remaining
    parameters are unused.

    Returns:
        ``(d, dtw)``: the last squared distance and the array passed in.
    """
    num_dims = s1.shape[-1]
    for row in range(r):
        for col in range(c):
            d = 0
            for axis in range(num_dims):
                d += (s1[row, axis] - s2[col, axis]) ** 2
    return d, dtw
# Benchmark of the stripped-down kernels: 300 queries against 1e6 samples.
mydtype = np.float32
series1 = np.random.random((16, 300, 2)).astype(mydtype)
series2 = np.random.random((1000000, 2)).astype(mydtype)
n_series = series1.shape[1]
r = series1.shape[0]
c = series2.shape[0]
# NOTE(review): 300 * 17 * 1000001 float32 values is roughly 20 GB, although
# the stripped-down p_calc never touches this array - consider shrinking it.
dtw = np.full((n_series, r + 1, c + 1), np.inf, dtype=mydtype)
d = np.full((n_series), np.inf, dtype=mydtype)

# Warm-up: the first call of an @njit function compiles it, which would
# otherwise be counted in the timings below. The tiny inputs are slices of the
# real data so that dtypes and memory layouts match the timed calls.
_warm_s1 = series1[:2]
_warm_s2 = series2[:4]
_warm_r = _warm_s1.shape[0]
_warm_c = _warm_s2.shape[0]
p_calc(np.full(n_series, np.inf, dtype=mydtype),
       np.full((n_series, _warm_r + 1, _warm_c + 1), np.inf, dtype=mydtype),
       _warm_s1, _warm_s2, _warm_r, _warm_c, 0, 0,
       _warm_c, _warm_c, _warm_c, np.inf, np.inf, 0.01, False)
calc(np.full((_warm_r + 1, _warm_c + 1), np.inf, dtype=mydtype),
     _warm_s1[:, 0, :], _warm_s2, _warm_r, _warm_c, 0, 0,
     _warm_c, _warm_c, _warm_c, np.inf, np.inf, 0.01, False)

# All queries in a single parallel call.
time1 = time.perf_counter()
d1, dtw1 = p_calc(d, dtw, series1, series2, series1.shape[0], series2.shape[0], 0, 0,
                  series2.shape[0], series2.shape[0], series2.shape[0], np.inf, np.inf, 0.01, False)
print(time.perf_counter() - time1)

# One call per query. The buffer is allocated before the timer starts so that
# only the kernel is measured.
dtw = np.full((r + 1, c + 1), np.inf, dtype=mydtype)
time1 = time.perf_counter()
for ii in range(series1.shape[1]):
    d2, dtw2 = calc(dtw, series1[:, ii, :], series2, series1.shape[0], series2.shape[0], 0, 0,
                    series2.shape[0], series2.shape[0], series2.shape[0], np.inf, np.inf, 0.01, False)
print(time.perf_counter() - time1)

# Compare the last query of both variants; a bare expression would discard
# the result when this runs as a script.
print(np.allclose(dtw1[-1], dtw2))
print(np.allclose(d1[-1], d2))
编辑:
我发现以下代码在使用 pass 与使用 break 时性能不同,我不明白这是为什么。
@njit(fastmath=True, nogil=True)
def kbest_matches(matching,k=4000):
    """Minimal reproduction: call ``np.argmin(matching)`` up to ``k`` times.

    The loop stops at the first iteration in which the minimum is at index 0.
    Replacing the ``break`` with ``pass`` makes it run all ``k`` iterations.
    Always returns 0.
    """
    for _ in range(k):
        if np.argmin(matching) == 0:
            break
    return 0
ss = np.random.random((1575822,))
# Warm-up on a one-element slice (same dtype and layout) so that JIT
# compilation of the first call is not part of the measurement.
kbest_matches(ss[:1])
time1 = time.perf_counter()
pp = kbest_matches(ss)
print(time.perf_counter() - time1)