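I want to speed up a function that I use a lot, and I wanted to use Cython. However, after trying every Cython optimization I could find in the documentation, the Cython code is about 6 times slower than the Python+NumPy function. Disappointing!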


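```python
import numpy as np

def forward1(points, rotation, translation):
    '''points are in columns'''
    # Subtract the translation from every column (broadcasting), then rotate.
    return np.dot(rotation, points - translation[:, np.newaxis])
```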

```cython
import numpy as np
cimport numpy as np
cimport cython

cdef np.float64_t[:, :] forward2(np.float64_t[:, :] points,
                                 np.float64_t[:, :] rotation,
                                 np.float64_t[:] translation):
    '''points are in columns'''
    cdef Py_ssize_t I = points.shape[0]
    cdef Py_ssize_t J = points.shape[1]
    cdef np.float64_t[:, :] tmp = np.empty((I, J), dtype=np.float64)
    cdef Py_ssize_t i, k
    # Subtract the translation from each column, element by element
    # (every row, instead of only hard-coding rows 0 and 1).
    for i in range(J):
        for k in range(I):
            tmp[k, i] = points[k, i] - translation[k]
    # The matrix product itself is still delegated to numpy.dot.
    cdef np.float64_t[:, :] result = np.dot(rotation, tmp)
    return result

def test_forward2(points, rotation, translation):
    import timeit
    cdef np.float64_t[:, :] points2 = points
    cdef np.float64_t[:, :] rotation2 = rotation
    cdef np.float64_t[:] translation2 = translation
    t = timeit.Timer(lambda: forward2(points2, rotation2, translation2))
    print(min(t.repeat(3, 10)))
```


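```python
import timeit

# points, rotation and translation are the input arrays (not shown);
# forward1 and test_forward2 come from the two modules above.
t = timeit.Timer(lambda: forward1(points, rotation, translation))
print(min(t.repeat(3, 10)))

test_forward2(points, rotation, translation)
```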
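The timing script above depends on inputs that are not shown. A self-contained sketch of the NumPy side of the benchmark might look like this; the array sizes, the rotation angle and the translation values are hypothetical, chosen only so the snippet runs on its own.

```python
import timeit
import numpy as np

def forward1(points, rotation, translation):
    '''points are in columns'''
    return np.dot(rotation, points - translation[:, np.newaxis])

# Hypothetical inputs: 2-D points stored in columns, a 2x2 rotation and a
# 2-vector translation. The sizes are illustrative, not from the question.
rng = np.random.default_rng(0)
n_points = 100_000
points = rng.standard_normal((2, n_points))
theta = 0.3
rotation = np.array([[np.cos(theta), -np.sin(theta)],
                     [np.sin(theta),  np.cos(theta)]])
translation = np.array([1.0, -2.0])

# Best of 3 repeats, each timing 10 calls, as in the original script.
t = timeit.Timer(lambda: forward1(points, rotation, translation))
best = min(t.repeat(3, 10))
print(f"best of 3 x 10 calls: {best:.4f} s")
```

The Cython half still has to be compiled (for example with `cythonize`) before `test_forward2` can be timed the same way.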

Is there anything I can do to the Cython code to make it faster?

If forward1 cannot be sped up with Cython, can I hope to speed it up with weave instead?


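Just for the record, another thing I tried was passing the points in Fortran order, since my points are stored in columns and there are a lot of them. I also allocated the local tmp in Fortran order. I expected the subtraction part of the function to get faster, but numpy.dot seems to require a C-ordered output array (is there any way around this?), so overall there was no speedup. I also tried transposing the points so that the subtraction is faster in C order, but the dot product still seems to be the most expensive part.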
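One possible way around the C-order requirement, not taken from the question and not benchmarked here: a Fortran-ordered `(2, N)` array is C-contiguous when viewed through its transpose, so the transposed product `tmp.T @ rotation.T` can be written directly into `result.T`. The inputs below are hypothetical, and whether this is actually faster will depend on the BLAS library numpy uses.

```python
import numpy as np

# Hypothetical inputs: points in columns, stored in Fortran order.
rng = np.random.default_rng(1)
points = np.asfortranarray(rng.standard_normal((2, 1000)))
rotation = np.array([[0.0, -1.0], [1.0, 0.0]])
translation = np.array([0.5, 0.25])

# Keep the column layout for the subtraction: allocate tmp in Fortran order.
tmp = np.empty_like(points, order='F')
np.subtract(points, translation[:, np.newaxis], out=tmp)

# np.dot only accepts a C-contiguous `out`. result is Fortran-ordered, so
# result.T is C-contiguous; (rotation @ tmp).T == tmp.T @ rotation.T, and
# writing that product into result.T fills result in Fortran order.
result = np.empty((2, points.shape[1]), order='F')
np.dot(tmp.T, rotation.T, out=result.T)
```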

Also, I noticed that numpy.dot cannot use a memoryview as its output argument, even when it is C-ordered. Is this a bug?
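This is more likely a restriction than a bug: as far as I know, `np.dot` expects its `out` argument to be an actual ndarray, whereas a Cython typed memoryview is converted to a Cython memoryview object when passed to Python code. A common workaround is to wrap the view with `np.asarray`, which shares the memory instead of copying. The sketch below demonstrates this with Python's built-in `memoryview` standing in for the Cython one (an assumption made so that it runs without compiling anything).

```python
import numpy as np

a = np.arange(6, dtype=np.float64).reshape(2, 3)
b = np.ones((3, 2))

buf = np.zeros((2, 2))   # C-ordered float64 storage
mv = memoryview(buf)     # a buffer-protocol view, not an ndarray

# Wrap the view in an ndarray before handing it to np.dot as `out`.
# np.asarray does not copy here: the new array shares buf's memory.
out = np.asarray(mv)
np.dot(a, b, out=out)

print(buf)  # buf itself now holds the product
```

In a `.pyx` file the same call would read `np.dot(rotation, tmp, out=np.asarray(result))`, with `result` being a C-contiguous `np.float64_t[:, :]` view.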





Cython is great for speeding up cases where numpy often performs poorly (e.g. iterative algorithms where the iteration is written in Python), but in this case the inner loop, the matrix product inside np.dot, is already being performed by a BLAS library.
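For contrast, here is a minimal sketch of the kind of loop where Cython tends to pay off: a linear recurrence in which every step depends on the previous result, which is awkward to express as one vectorized numpy call. The function name and the recurrence are hypothetical examples, not from the question.

```python
import numpy as np

def linear_recurrence(x, a):
    """y[0] = x[0]; y[i] = a * y[i - 1] + x[i] for i >= 1."""
    y = np.empty_like(x, dtype=np.float64)
    acc = 0.0
    for i in range(len(x)):
        # Each step needs the previous step's result, so this loop runs
        # in the Python interpreter; typing it in Cython removes that cost.
        acc = a * acc + x[i]
        y[i] = acc
    return y

print(linear_recurrence(np.array([1.0, 1.0, 1.0]), 0.5))
```

In `forward1`, by contrast, the only loops are inside numpy's subtraction and `np.dot`, so there is little interpreter overhead left for Cython to remove.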

If you want to speed things up, the first place I'd look is which BLAS/LAPACK/ATLAS/etc. libraries numpy is linked against. Using a "tuned" linear algebra library (e.g. ATLAS or Intel's MKL) can make a large difference (more than 10x in some cases) in situations like this.

To find out what you're currently using, have a look at the output of numpy.show_config().
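For example, the following prints the build configuration. The BLAS/LAPACK section shows which libraries numpy was linked against, though the exact layout of the output differs between numpy versions.

```python
import numpy as np

# Print numpy's build configuration, including the BLAS/LAPACK
# libraries it was built against (output format varies by version).
np.show_config()
```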

Answered 2012-10-13T21:40:21.567