我正在尝试优化 `numpy.packbits`:
import numpy as np
from numba import njit, prange
@njit(parallel=True)
def _numba_pack(arr, div, su):
    """Pack ``8 * div`` values of *arr* into ``su[:div]``, most significant bit first.

    Byte ``k`` is built from ``arr[8*k : 8*k + 8]`` by repeated doubling and
    adding, so the first element of each group lands in bit 7. Every byte is
    computed independently of the others, which is what lets the outer loop
    run under ``prange``. The accumulation assumes 0/1 (bool) values; larger
    values would spill into neighbouring bits.
    """
    for byte_idx in prange(div):
        base = 8 * byte_idx
        acc = 0
        for k in range(8):
            acc = acc * 2 + arr[base + k]
        su[byte_idx] = acc
def numba_packbits(arr):
    """Pack the elements of *arr* into bits, MSB first, like ``np.packbits(arr)``.

    The input is flattened and its elements are interpreted as truth values
    (non-zero -> 1). Complete groups of eight bits are packed by the parallel
    ``_numba_pack`` kernel; a trailing partial group is packed here and padded
    with zero bits on the right.

    Parameters
    ----------
    arr : array_like
        Values to pack, of any shape; bool or numeric.

    Returns
    -------
    numpy.ndarray
        1-D ``uint8`` array of length ``ceil(arr.size / 8)``.
    """
    # Flatten (as np.packbits does with axis=None) and normalise to a
    # C-contiguous bool array. For 1-D contiguous bool input this is a no-op
    # view. Previously N-d input was sliced by rows, and non-bool values > 1
    # corrupted the kernel's ``2*s + bit`` accumulation.
    flat = np.ascontiguousarray(arr, dtype=np.bool_).ravel()
    div, mod = divmod(flat.size, 8)
    su = np.zeros(div + (mod > 0), dtype=np.uint8)
    _numba_pack(flat[:div * 8], div, su)
    if mod:
        # Last partial byte: tail bit k goes to bit position 7 - k.
        su[-1] = sum(int(bit) << (7 - k) for k, bit in enumerate(flat[div * 8:]))
    return su
>>> X = np.random.randint(2, size=99, dtype=bool)
>>> print(numba_packbits(X))
[ 75 24 79 61 209 189 203 187 47 226 170 61 0]
它看起来比 `np.packbits(X)` 慢 2–2.5 倍。`numpy` 内部是如何实现这一功能的?能否用 `numba` 改进这一实现?
我的平台是:`numpy == 1.21.2`,`numba == 0.53.1`,均通过 `conda install` 安装。
结果:
# Benchmark numba_packbits against numpy.packbits over increasing input sizes.
# NOTE: `%matplotlib inline` is an IPython/Jupyter magic, so this snippet has
# to run in a notebook / IPython session, not as a plain .py script.
import benchit
from numpy import packbits
%matplotlib inline
benchit.setparams(rep=5)
# Input sizes in bits. One random boolean array of the largest size is created
# once, and prefixes of it are used for the smaller sizes.
sizes = [100000, 300000, 1000000, 3000000, 10000000, 30000000]
N = sizes[-1]
arr = np.random.randint(2, size=N, dtype=bool)
fns = [numba_packbits, packbits]
# Keys label the x-axis (millions of bits). Each value is a 1-tuple holding the
# input array; presumably benchit unpacks it as positional arguments because
# of multivar=True -- confirm against the benchit docs.
in_ = {s/1000000: (arr[:s], ) for s in sizes}
# NOTE(review): numba_packbits is JIT-compiled on its first call. Unless benchit
# discards warm-up runs, the smallest size may include compilation time;
# consider calling numba_packbits once before timing.
t = benchit.timings(fns, in_, multivar=True, input_name='Millions of bits')
t.plot(logx=True, figsize=(12, 6), fontsize=14)
更新
Jérôme 的回应是:
@njit('void(bool_[::1], uint8[::1], int_)', inline='never')
def _numba_pack_x64_byJérôme(arr, su, pos):
    """Pack ``arr[0:512]`` into ``su[0:64]``, most significant bit first.

    The eight shift/OR steps per output byte are written out explicitly
    rather than looped over. The compiled signature requires C-contiguous
    ``bool`` input and ``uint8`` output. ``pos`` is part of that signature
    but is not used by the body.
    """
    for out in range(64):
        b = 8 * out
        su[out] = (
            (arr[b] << 7)
            | (arr[b + 1] << 6)
            | (arr[b + 2] << 5)
            | (arr[b + 3] << 4)
            | (arr[b + 4] << 3)
            | (arr[b + 5] << 2)
            | (arr[b + 6] << 1)
            | arr[b + 7]
        )
@njit(parallel=True)
def _numba_pack_byJérôme(arr, div, su):
    """Pack ``arr[:8*div]`` into ``su[:div]``, most significant bit first.

    Complete 64-byte chunks are distributed over threads with ``prange``.
    Chunk ``i`` reads input bits ``[512*i, 512*(i+1))`` and writes output bytes
    ``[64*i, 64*(i+1))``. The remaining ``div % 64`` bytes are packed serially
    afterwards.
    """
    nchunks = div // 64
    for i in prange(nchunks):
        # Bug fix: chunks must advance by 512 bits / 64 bytes. The previous
        # slices arr[i*8:(i+64)*8] and su[i:i+64] advanced by only one byte,
        # so chunks overlapped. Output bytes from nchunks+63 up to 64*nchunks
        # were never written and stayed zero (any input >= 1024 bits).
        _numba_pack_x64_byJérôme(arr[i*512:(i+1)*512], su[i*64:(i+1)*64], i)
    # Tail that does not fill a whole 64-byte chunk.
    for i in range(nchunks*64, div):
        j = i * 8
        su[i] = (arr[j]<<7)|(arr[j+1]<<6)|(arr[j+2]<<5)|(arr[j+3]<<4)|(arr[j+4]<<3)|(arr[j+5]<<2)|(arr[j+6]<<1)|arr[j+7]
def numba_packbits_byJérôme(arr):
    """Pack the elements of *arr* into bits, MSB first, like ``np.packbits(arr)``.

    The input is flattened and its elements are interpreted as truth values
    (non-zero -> 1). Complete groups of eight bits are packed by the
    chunked parallel kernel ``_numba_pack_byJérôme``; a trailing partial group
    is packed here and padded with zero bits on the right.

    Parameters
    ----------
    arr : array_like
        Values to pack, of any shape; bool or numeric.

    Returns
    -------
    numpy.ndarray
        1-D ``uint8`` array of length ``ceil(arr.size / 8)``.
    """
    # The inner kernel is compiled for C-contiguous bool (``bool_[::1]``).
    # Previously, strided or non-bool input failed with a numba typing error,
    # and N-d input was sliced by rows. Flatten and normalise here; for 1-D
    # contiguous bool input this is a no-op view.
    flat = np.ascontiguousarray(arr, dtype=np.bool_).ravel()
    div, mod = divmod(flat.size, 8)
    su = np.zeros(div + (mod > 0), dtype=np.uint8)
    _numba_pack_byJérôme(flat[:div * 8], div, su)
    if mod:
        # Last partial byte: tail bit k goes to bit position 7 - k.
        su[-1] = sum(int(bit) << (7 - k) for k, bit in enumerate(flat[div * 8:]))
    return su
用法:
>>> print(numba_packbits_byJérôme(X))
[ 75 24 79 61 209 189 203 187 47 226 170 61 0]
结果: