这里是一个简单的练习,带有一个辛普森集成代码,我已经制作了它来接受几个函数来集成一组边界
import numpy as np
def simps(f, a, b, N):
#N should be even
dx = (b - a) / N
x = np.linspace(a, b, N + 1)
y = f(x)
w = np.ones_like(y)
w[2:-1:2] = 2.
w[1::2] = 4.
S = dx / 3 * np.einsum("i...,i...",w,y)
return S
def funcN(x):
return np.stack([x**(i/10) * np.exp(-x) for i in range(200)],axis=1)
a = np.arange(0,10,0.1)
b = a+0.05
我在 CPU 设备上,然后我得到一个 200 x 100 数字数组,对应于 Int(f_i, a_j,b_j) i:0-199 和 j:0-99
%timeit simps(funcN,a,b, 512)
每个循环 1.13 秒 ± 27.4 毫秒(平均值 ± 标准偏差。7 次运行,每个循环 1 个)
现在考虑以下 JAX/JIT 版本
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap
from functools import partial
from jax.config import config
config.update("jax_enable_x64", True) #numpy by default is in double precision
@partial(jit, static_argnums=(0,3))
def jax_simps(f, a,b, N):
dx = (b - a) / N
x = jnp.linspace(a, b, N + 1)
y = f(x)
w = jnp.ones_like(y)
w = w.at[2:-1:2].set(2.)
w = w.at[1::2].set(4.)
S = dx / 3. * jnp.einsum('i...,i...',w,y)
return S
@jit
def jax_funcN(x):
return jnp.stack([x**(i/10) * jnp.exp(-x) for i in range(200)],axis=1)
ja = jnp.arange(0,10,0.1)
jb = ja+0.05
#warm up
jax_simps(jax_funcN,ja,jb, 512).block_until_ready()
%timeit jax_simps(jax_funcN,ja,jb, 512).block_until_ready()
我已经验证了这两个代码(纯 Numpy 和 JAX/JIT)给出的结果相同,因为最大相对误差约为 8. 10^-16。
现在,我得到以下时间 933 ms ± 51.4 ms 每个循环(平均值±标准偏差。7 次运行,每个循环 1 个)
这与纯 Numpy 非常接近。我是否偶然制作了一个非常有效的纯 Numpy 代码???还是我以错误的方式编码 JAX/JIT?
(nb. 使用 Google collab K80 GPU 时,每个循环的 JAX/JIT 时间下降到 7.19 毫秒,将纯 Numpy 保持在 1 秒/循环的水平)