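If you really want numba to be fast, you need to jit the function in nopython mode; otherwise numba may fall back to object mode, which is slower (and can be very slow).
However, your function cannot be compiled in nopython mode (as of numba version 0.43.1), for two reasons:
- The dtype argument to np.empty. np.float is simply the Python float, which NumPy (but not numba) translates to np.float_. If you use numba, you have to use the NumPy type itself (np.float64 in the code below).
- String support in numba is lacking, so the types[k] == 'float64' line will not compile.
The first issue is trivial to fix. For the second: instead of trying to make the string comparison work, provide a boolean array. Using a boolean array and evaluating one boolean for truthiness is also much faster than comparing up to 7 characters, especially in the innermost loop!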
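As a minimal sketch of that idea, the mask can be built once, outside the compiled function. The column labels here are made-up example values, and the sketch assumes types holds dtype names as strings. The np.asarray call guards against types being a plain list, where == would compare the whole list to the string and return a single False.

```python
import numpy as np

# Hypothetical column labels; the real `types` comes from your own data.
types = ['float64', 'int64', 'float64', 'object']

# Comparing a plain list to a string yields one bool, not a mask.
not_a_mask = (types == 'float64')

# Convert to an array first so the comparison is element-wise.
is_float_type = np.asarray(types) == 'float64'
print(not_a_mask, is_float_type)
```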
So it could look like this:
import numpy as np
import numba as nb

@nb.njit
def pairwise_numba(X, is_float_type):
    m = X.shape[0]
    n = X.shape[1]

    D = np.empty((int(m * (m - 1) / 2), 1), dtype=np.float64)  # corrected dtype
    ind = 0

    for i in range(m):
        for j in range(i+1, m):
            d = 0.0

            for k in range(n):
                if is_float_type[k]:
                    tmp = X[i, k] - X[j, k]
                    d += tmp * tmp
                else:
                    if X[i, k] != X[j, k]:
                        d += 1.

            D[ind] = np.sqrt(d)
            ind += 1

    return D.reshape(1, -1)[0]

dists = pairwise_numba(vectors, types == 'float64')  # pass in the boolean array
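To sanity-check the compiled function on a small input, a plain-Python reference can be compared against it. The sketch below is not part of the original answer: pairwise_reference is a hypothetical helper name, and the 3x2 data is invented so the expected distances can be worked out by hand.

```python
from itertools import combinations
import math
import numpy as np

def pairwise_reference(X, is_float_type):
    """Slow reference with the same metric as pairwise_numba, no compilation."""
    out = []
    # combinations() yields (0, 1), (0, 2), (1, 2), ... matching the loop order
    for i, j in combinations(range(X.shape[0]), 2):
        d = 0.0
        for k in range(X.shape[1]):
            if is_float_type[k]:
                d += (X[i, k] - X[j, k]) ** 2   # squared difference
            elif X[i, k] != X[j, k]:
                d += 1.0                        # categorical mismatch
        out.append(math.sqrt(d))
    return np.array(out)

# Column 0 is a float feature, column 1 a category code.
X = np.array([[0.0, 1.0],
              [3.0, 1.0],
              [0.0, 2.0]])
mask = np.array([True, False])

ref = pairwise_reference(X, mask)
# Pairs (0,1), (0,2), (1,2): sqrt(9 + 0), sqrt(0 + 1), sqrt(9 + 1)
print(ref)
```

With numba installed, np.allclose(pairwise_numba(X, mask), ref) should hold for the same inputs.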
However, you can simplify the logic by combining scipy.spatial.distance.pdist for the float columns with numba logic that counts the mismatched categoricals:
from scipy.spatial.distance import pdist

@nb.njit
def categorial_sum(X):
    m = X.shape[0]
    n = X.shape[1]
    D = np.zeros(int(m * (m - 1) / 2), dtype=np.float64)  # corrected dtype
    ind = 0

    for i in range(m):
        for j in range(i+1, m):
            d = 0.0
            for k in range(n):
                if X[i, k] != X[j, k]:
                    d += 1.
            D[ind] = d
            ind += 1

    return D

def pdist_with_categorial(vectors, types):
    where_float_type = types == 'float64'
    # calculate the squared distance of the float values
    distances_squared = pdist(vectors[:, where_float_type], metric='sqeuclidean')
    # sum the number of mismatched categorials and add that to the distances
    # and then take the square root
    return np.sqrt(distances_squared + categorial_sum(vectors[:, ~where_float_type]))
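The split works because the squared distance is a sum over columns, so the float and categorical columns can be accumulated separately and added before taking the square root. Here is a NumPy-only sketch of that decomposition on invented data. Using np.triu_indices to enumerate the pairs is a choice made for this sketch, not something from the code above; it produces the pairs in the same i < j order as the loops.

```python
import numpy as np

# Column 0 is a float feature; columns 1 and 2 are category codes.
X = np.array([[0.0, 1.0, 5.0],
              [3.0, 1.0, 6.0],
              [0.0, 2.0, 5.0]])
where_float = np.array([True, False, False])

i, j = np.triu_indices(X.shape[0], k=1)   # pairs (0,1), (0,2), (1,2)

F = X[:, where_float]
C = X[:, ~where_float]
sq_float = ((F[i] - F[j]) ** 2).sum(axis=1)   # squared Euclidean part
mismatches = (C[i] != C[j]).sum(axis=1)       # count of unequal categories

combined = np.sqrt(sq_float + mismatches)
print(combined)
```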
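It won't be noticeably faster, but it greatly simplifies the logic in the numba function.
You can then also avoid the extra array creation by passing the squared distances into the numba function: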
@nb.njit
def add_categorial_sum_and_sqrt(X, D):
    m = X.shape[0]
    n = X.shape[1]
    ind = 0
    for i in range(m):
        for j in range(i+1, m):
            d = 0.0
            for k in range(n):
                if X[i, k] != X[j, k]:
                    d += 1.
            D[ind] = np.sqrt(D[ind] + d)
            ind += 1

    return D

def pdist_with_categorial(vectors, types):
    where_float_type = types == 'float64'
    distances_squared = pdist(vectors[:, where_float_type], metric='sqeuclidean')
    return add_categorial_sum_and_sqrt(vectors[:, ~where_float_type], distances_squared)
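One consequence worth noting: the function writes its result into D, so the array returned by pdist is overwritten rather than copied. The sketch below repeats the loop without the @nb.njit decorator so that it runs with NumPy alone. The _py suffix and the input values are made up for illustration, and the squared array is a hand-written stand-in for pdist output.

```python
import numpy as np

def add_categorial_sum_and_sqrt_py(X, D):
    # Undecorated copy of the loop above, for illustration only.
    m, n = X.shape
    ind = 0
    for i in range(m):
        for j in range(i + 1, m):
            d = 0.0
            for k in range(n):
                if X[i, k] != X[j, k]:
                    d += 1.0
            D[ind] = np.sqrt(D[ind] + d)   # overwrite the input in place
            ind += 1
    return D

cats = np.array([[1.0], [1.0], [2.0]])   # one categorical column
squared = np.array([9.0, 0.0, 9.0])      # stand-in for pdist(..., 'sqeuclidean')
result = add_categorial_sum_and_sqrt_py(cats, squared)

print(result is squared)   # the input buffer was reused, not copied
```

If the squared distances are still needed afterwards, passing D.copy() keeps the original intact at the cost of the allocation this version was meant to avoid.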