0

我是 numba 的用户,谁能告诉我为什么 numpy 数组的切片这么慢,这是一个例子:

def pairwise_python2(X):

    n_samples = X.shape[0]

    result = np.zeros((n_samples, n_samples), dtype=X.dtype)

    for i in xrange(X.shape[0]):

        for j in xrange(X.shape[0]):

            result[i, j] = np.sqrt(np.sum((X[i, :] - X[j, :]) ** 2))

    return result

%timeit pairwise_python2(X)

1 个循环,最好的 3 个:每个循环 18.2 秒

from numba import double

from numba.decorators import jit, autojit

pairwise_numba = autojit(pairwise_python)

%timeit pairwise_numba(X)

1 个循环,最好的 3 个:每个循环 13.9 秒

jit和cpython版本似乎没有区别,我错了吗?

4

2 回答 2

1

您正在计时 numpy 内存分配。X[i,:] - X[j,:] 生成一个形状为 (n_samples, n_samples) 的新矩阵,平方运算也是如此。请尝试以下方法:

def pairwise_python2(X):
    n_samples = X.shape[0]
    result = np.empty((n_samples, n_samples), dtype=X.dtype)
    temp = np.empty((n_samples,), dtype=X.dtype)
    for i in xrange(n_samples):
        slice = X[i,:]
        for j in xrange(n_samples):
            result[i,j] = np.sqrt(np.sum(np.power(np.subtract(slice,X[j,:],temp),2.0,temp)))
    return result

Numba 并没有为此增加很多,因为您在 numpy 中执行所有操作(尽管它会加速循环迭代,这在您的计时函数中可以看到)。

于 2013-12-09T20:30:24.460 回答
1

新版本numba支持numpy数组切片和np.sqrt()函数。所以,这个问题可以结束了。

于 2015-05-25T05:02:54.640 回答