我正在尝试学习 numba,因此作为入门练习,我编写了一个简单的轨道求解器:
import numba as nb
import numpy as np
from timeit import default_timer as timer
spec = [('x0', nb.types.float64),
('y0', nb.types.float64),
('vx0', nb.types.float64),
('vy0', nb.types.float64),
('mass', nb.types.float64),
('ax', nb.types.float64),
('ay', nb.types.float64),
('x', nb.types.float64[:]),
('y', nb.types.float64[:]),
('vx', nb.types.float64[:]),
('vy', nb.types.float64[:])]
@nb.jitclass(spec)
class CelestialBody():
def __init__(self, x0, y0, vx0, vy0, mass):
self.x0 = x0
self.y0 = y0
self.vx0 = vx0
self.vy0 = vy0
self.mass = mass
self.ax = 0.0
self.ay = 0.0
@nb.jit(nopython=True, cache=True)
def orbit(bodies, delta_t, nsteps):
# Set up position arrays
for j in range(len(bodies)):
bodies[j].x = np.zeros(nsteps, dtype=np.float64)
bodies[j].y = np.zeros(nsteps, dtype=np.float64)
bodies[j].vx = np.zeros(nsteps, dtype=np.float64)
bodies[j].vy = np.zeros(nsteps, dtype=np.float64)
bodies[j].x[0] = bodies[j].x0
bodies[j].y[0] = bodies[j].y0
bodies[j].vx[0] = bodies[j].vx0
bodies[j].vy[0] = bodies[j].vy0
# Loop over every time step (skip 0 since we have x0 and y0)
for i in range(0, nsteps-1):
# Get gravitational acceleration for each body at current time
for j in range(len(bodies)):
# Reset accelerations
bodies[j].ax = 0.0
bodies[j].ay = 0.0
for k in range(len(bodies)):
if j != k:
# Get distance between objects
dx = bodies[j].x[i] - bodies[k].x[i]
dy = bodies[j].y[i] - bodies[k].y[i]
d = np.sqrt(dx**2. + dy**2.)
# Get acceleration
a = -bodies[k].mass / d**2.
# Separate into x and y components
theta = np.arctan2(dy, dx)
bodies[j].ax += a * np.cos(theta)
bodies[j].ay += a * np.sin(theta)
# Update positions
for j in range(len(bodies)):
bodies[j].vx[i+1] += bodies[j].vx[i] + bodies[j].ax * delta_t
bodies[j].vy[i+1] += bodies[j].vy[i] + bodies[j].ay * delta_t
bodies[j].x[i+1] += bodies[j].x[i] + bodies[j].vx[i] * delta_t +\
0.5 * bodies[j].ax * delta_t**2.
bodies[j].y[i+1] += bodies[j].y[i] + bodies[j].vy[i] * delta_t + 0.5 *\
bodies[j].ay * delta_t**2
return bodies
for i in range(10):
# Set up celestial bodies
sun = CelestialBody(0., 0., 0., 0., 1.)
earth = CelestialBody(1., 0., 0., 6.33, 3.00e-6)
bodies = [sun, earth]
# Set up time info
tf = 100.
delta_t = tf / 365.
nsteps = int(tf / delta_t)
# Orbit
start = timer()
bodies = orbit(bodies, delta_t, nsteps)
end = timer()
print('Time to run: %f' % (end - start))
该代码可以在没有 numba 的情况下运行和运行。当我添加 numba 时,我可以 jit 我的类和函数,它运行得很好,提供了很好的加速。但是,当我尝试使用 cache=True 缓存 jitt'ed 函数时,我得到一个 KeyError:
File "/usr/local/lib/python3.6/dist-packages/numba/caching.py", line 482, in save
data_name = overloads[key]
KeyError: ((reflected list(instance.jitclass.CelestialBody#2cef1b8<x0:float64,
y0:float64,vx0:float64,vy0:float64,mass:float64,ax:float64,ay:float64,
x:array(float64, 1d, A),y:array(float64, 1d, A),vx:array(float64, 1d, A),
vy:array(float64, 1d, A)>), float64, int64), ('x86_64-unknown-linux-gnu',
'skylake', '+adx,+aes,+avx,+avx2,-avx512bitalg,-avx512bw,-avx512cd,-avx512dq,
-avx512er,-avx512f,-avx512ifma,-avx512pf,-avx512vbmi,-avx512vbmi2,-avx512vl,
-avx512vnni,-avx512vpopcntdq,+bmi,+bmi2,-cldemote,+clflushopt,-clwb,-clzero,+cmov,
+cx16,+f16c,+fma,-fma4,+fsgsbase,-gfni,+invpcid,-lwp,+lzcnt,+mmx,+movbe,-movdir64b,
-movdiri,-mwaitx,+pclmul,-pconfig,-pku,+popcnt,-prefetchwt1,+prfchw,-ptwrite,
-rdpid,+rdrnd,+rdseed,-rtm,+sahf,+sgx,-sha,-shstk,+sse,+sse2,+sse3,+sse4.1,
+sse4.2,-sse4a,+ssse3,-tbm,-vaes,-vpclmulqdq,-waitpkg,-wbnoinvd,-xop,+xsave,
+xsavec,+xsaveopt,+xsaves'))
我意识到上面的大部分内容都是编译器标志等,可能是不必要的,但我不确定,所以我想我会包括它。
还有一个泡菜错误:
_pickle.PicklingError: Can't pickle <class '__main__.CelestialBody'>: it's not the same object as __main__.CelestialBody
我试过看这个问题,但据我所知,没有导入错误,而且我没有弄乱我正在导入的任何模块。我也没有在 jupyter 笔记本上运行,只是一个终端。我的猜测是它在编译之前和之后与类“签名”有关,并且泡菜对这种变化感到困惑。当不使用类时,我可以让缓存工作。
我正在使用 Python 3.6.7 版、numpy 1.15.4 版和 numba 0.42.1 版
所以,我的问题是是什么导致了这个阻止缓存的泡菜错误?谢谢!