10

TLDR:在 cython 中,为什么(或何时?)迭代 numpy 数组比迭代 python 列表更快?

一般来说:我以前使用过 Cython,并且能够比幼稚的 python impl' 获得巨大的加速,但是,弄清楚究竟需要做什么似乎并非易事。

考虑以下 3 个 sum() 函数的实现。它们驻留在一个名为“cy”的 cython 文件中(显然,有 np.sum(),但这不是我的意思..)

天真的蟒蛇:

def sum_naive(A):
   s = 0
   for a in A:
       s += a
   return s

Cython 带有一个需要 python 列表的函数:

def sum_list(A):
    cdef unsigned long s = 0
    for a in A:
        s += a
    return s

Cython 的函数需要一个 numpy 数组。

def sum_np(np.ndarray[np.int64_t, ndim=1] A):
    cdef unsigned long s = 0
    for a in A:
        s += a
    return s

我希望就运行时间而言,sum_np < sum_list < sum_naive,但是,以下脚本恰恰相反(为了完整起见,我添加了 np.sum() )

N = 1000000
v_np = np.array(range(N))
v_list = range(N)

%timeit cy.sum_naive(v_list)
%timeit cy.sum_naive(v_np)
%timeit cy.sum_list(v_list)
%timeit cy.sum_np(v_np)
%timeit v_np.sum()

结果:

In [18]: %timeit cyMatching.sum_naive(v_list)
100 loops, best of 3: 18.7 ms per loop

In [19]: %timeit cyMatching.sum_naive(v_np)
1 loops, best of 3: 389 ms per loop

In [20]: %timeit cyMatching.sum_list(v_list)
10 loops, best of 3: 82.9 ms per loop

In [21]: %timeit cyMatching.sum_np(v_np)
1 loops, best of 3: 1.14 s per loop

In [22]: %timeit v_np.sum()
1000 loops, best of 3: 659 us per loop

这是怎么回事?为什么 cython+numpy 慢?

PS
我确实使用
#cython: boundscheck=False
#cython: wraparound=False

4

2 回答 2

11

有一种更好的方法可以在 cython 中实现这一点,至少在我的机器上可以做到这一点,np.sum因为它避免了类型检查和 numpy 在处理任意数组时通常必须做的其他事情:

#cython.wraparound=False
#cython.boundscheck=False
cimport numpy as np

def sum_np(np.ndarray[np.int64_t, ndim=1] A):
    cdef unsigned long s = 0
    for a in A:
        s += a
    return s

def sum_np2(np.int64_t[::1] A):
    cdef:
        unsigned long s = 0
        size_t k

    for k in range(A.shape[0]):
        s += A[k]

    return s

然后是时间:

N = 1000000
v_np = np.array(range(N))
v_list = range(N)

%timeit sum(v_list)
%timeit sum_naive(v_list)
%timeit np.sum(v_np)
%timeit sum_np(v_np)
%timeit sum_np2(v_np)
10 loops, best of 3: 19.5 ms per loop
10 loops, best of 3: 64.9 ms per loop
1000 loops, best of 3: 1.62 ms per loop
1 loops, best of 3: 1.7 s per loop
1000 loops, best of 3: 1.42 ms per loop

您不想通过 Python 样式迭代 numpy 数组,而是使用索引访问元素,因为它可以转换为纯 C,而不是依赖 Python API。

于 2013-09-21T22:43:22.773 回答
3

a是无类型的,因此会有很多从 Python 到 C 类型的转换。这些可能很慢。

JoshAdel 正确地指出,您应该迭代一个范围,而不是迭代。Cython 会将索引转换为 C,这很快。


使用cython -a myfile.pyx会为您突出显示这些东西;您希望所有循环逻辑都为白色以获得最大速度。

PS:请注意,它np.ndarray[np.int64_t, ndim=1]已过时,已被弃用,取而代之的是更快、更通用的long[:].

于 2013-09-21T22:07:47.260 回答