7

我最近偶然发现了numba,并考虑将一些自制的 C 扩展替换为更优雅的 autojitted python 代码。不幸的是,当我尝试第一个快速基准测试时,我并不高兴。似乎 numba 在这里的表现并不比普通的 python 好多少,尽管我本来期望接近 C 的性能:

from numba import jit, autojit, uint, double
import numpy as np
import imp
import logging
logging.getLogger('numba.codegen.debug').setLevel(logging.INFO)

def sum_accum(accmap, a):
    res = np.zeros(np.max(accmap) + 1, dtype=a.dtype)
    for i in xrange(len(accmap)):
        res[accmap[i]] += a[i]
    return res

autonumba_sum_accum = autojit(sum_accum)
numba_sum_accum = jit(double[:](int_[:], double[:]), 
                      locals=dict(i=uint))(sum_accum)

accmap = np.repeat(np.arange(1000), 2)
np.random.shuffle(accmap)
accmap = np.repeat(accmap, 10)
a = np.random.randn(accmap.size)

ref = sum_accum(accmap, a)
assert np.all(ref == numba_sum_accum(accmap, a))
assert np.all(ref == autonumba_sum_accum(accmap, a))

%timeit sum_accum(accmap, a)
%timeit autonumba_sum_accum(accmap, a)
%timeit numba_sum_accum(accmap, a)

accumarray = imp.load_source('accumarray', '/path/to/accumarray.py')
assert np.all(ref == accumarray.accum(accmap, a))

%timeit accumarray.accum(accmap, a)

这在我的机器上给出:

10 loops, best of 3: 52 ms per loop
10 loops, best of 3: 42.2 ms per loop
10 loops, best of 3: 43.5 ms per loop
1000 loops, best of 3: 321 us per loop

我正在运行 pypi 的最新 numba 版本,0.11.0。任何建议,如何修复代码,以便使用 numba 运行得相当快?

4

2 回答 2

6

我自己想通了。np.max(accmap)即使 accmap 的类型设置为 int,numba 也无法确定结果的类型。这不知何故减慢了一切,但修复很容易:

@autojit(locals=dict(reslen=uint))
def sum_accum(accmap, a):
    reslen = np.max(accmap) + 1
    res = np.zeros(reslen, dtype=a.dtype)
    for i in range(len(accmap)):
        res[accmap[i]] += a[i]
    return res

结果相当可观,大约是 C 版本的 2/3:

10000 loops, best of 3: 192 us per loop
于 2013-12-19T11:42:13.367 回答
4
@autojit
def numbaMax(arr):
    MAX = arr[0]
    for i in arr:
        if i > MAX:
            MAX = i
    return MAX

@autojit
def autonumba_sum_accum2(accmap, a):
    res = np.zeros(numbaMax(accmap) + 1)
    for i in xrange(len(accmap)):
        res[accmap[i]] += a[i]
    return res

10 loops, best of 3: 26.5 ms per loop <- original
100 loops, best of 3: 15.1 ms per loop <- with numba but the slow numpy max
10000 loops, best of 3: 47.9 µs per loop <- with numbamax
于 2013-12-19T19:59:33.483 回答