Here is a simple piece of JAX code that runs the Metropolis algorithm on a 3-parameter Bayesian regression problem. It works even on a CPU without JIT compilation. What I would like to understand is why, when the two JIT-related lines are uncommented, the timings are essentially the same whether I compare CPU with JIT against CPU without JIT, or CPU against a K80 Nvidia GPU.
Am I coding this in a wrong/inefficient way?
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import jacfwd, jacrev, hessian
from functools import partial
import scipy.stats as scs
import numpy as np
#@partial(jax.jit, static_argnums=(1,))
def jax_metropolis_kernel(rng_key, logpdf, position, log_prob):
    """Move the chain by one step using the Random Walk Metropolis algorithm."""
    key, subkey = jax.random.split(rng_key)
    # Gaussian random-walk proposal with step size 0.1
    move_proposals = jax.random.normal(key, shape=position.shape) * 0.1
    proposal = position + move_proposals
    proposal_log_prob = logpdf(proposal)
    # Metropolis accept/reject step
    log_uniform = jnp.log(jax.random.uniform(subkey))
    do_accept = log_uniform < proposal_log_prob - log_prob
    position = jnp.where(do_accept, proposal, position)
    log_prob = jnp.where(do_accept, proposal_log_prob, log_prob)
    return position, log_prob
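# Illustrative sanity check only (not used below): a single kernel step on a toy
# standard-normal target should return an updated (position, log_prob) pair, e.g.
#   toy_logpdf = lambda p: -0.5 * jnp.sum(p ** 2)
#   jax_metropolis_kernel(jax.random.PRNGKey(0), toy_logpdf, jnp.zeros(3),
#                         toy_logpdf(jnp.zeros(3)))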
#@partial(jax.jit, static_argnums=(1, 2))
def jax_metropolis_sampler(rng_key, n_samples, logpdf, initial_position):
    """Generate samples using the Random Walk Metropolis algorithm."""
    def mh_update(i, state):
        key, positions, log_prob = state
        _, key = jax.random.split(key)
        new_position, new_log_prob = jax_metropolis_kernel(key,
                                                           logpdf,
                                                           positions[i-1],
                                                           log_prob)
        positions = positions.at[i].set(new_position)
        return (key, positions, new_log_prob)

    # Initialisation
    keys = jax.random.split(rng_key, num=4)
    all_positions = jnp.zeros((n_samples, initial_position.shape[0]))  # 1 chain for each vmap call ?
    # all_positions = all_positions.at[0, 0].set(scs.norm.rvs(loc=1, scale=1))
    # all_positions = all_positions.at[0, 1].set(scs.norm.rvs(loc=2, scale=1))
    # all_positions = all_positions.at[0, 2].set(scs.uniform.rvs(loc=1, scale=2))
    all_positions = all_positions.at[0, 0].set(jax.random.normal(keys[0]) + 1.)
    all_positions = all_positions.at[0, 1].set(jax.random.normal(keys[1]) + 2.)
    all_positions = all_positions.at[0, 2].set(jax.random.uniform(keys[2], minval=1.0, maxval=3.0))
    logp = logpdf(all_positions[0])
    initial_state = (rng_key, all_positions, logp)
    rng_key, all_positions, log_prob = jax.lax.fori_loop(1, n_samples,
                                                         mh_update,
                                                         initial_state)
    return all_positions
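# Illustrative only (not used below): a single chain can be run outside vmap, e.g.
#   jax_metropolis_sampler(jax.random.PRNGKey(1), 1_000,
#                          lambda p: -0.5 * jnp.sum(p ** 2), jnp.ones(3))
# which returns an array of shape (1_000, 3).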
def jax_my_logpdf(par, xi, yi):
    # priors: a=par[0], b=par[1], sigma=par[2]
    logpdf_a = jax.scipy.stats.norm.logpdf(x=par[0], loc=1., scale=1.)
    logpdf_b = jax.scipy.stats.norm.logpdf(x=par[1], loc=2., scale=1.)
    logpdf_s = jax.scipy.stats.gamma.logpdf(x=par[2], a=3, scale=1.)
    # Gaussian likelihood of the linear model y = a + b*x with error scale sigma
    val = xi * par[1] + par[0]
    tmp = jax.scipy.stats.norm.logpdf(x=val, loc=yi, scale=par[2])
    log_likeh = jnp.sum(tmp)
    # log-posterior (up to a constant) = log-likelihood + log-priors
    return log_likeh + logpdf_a + logpdf_b + logpdf_s
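# Illustrative check only (not used below): once xi and yi are defined, the
# log-posterior should be finite at the true parameters, e.g.
#   jax_my_logpdf(jnp.array([1.0, 2.0, 1.5]), xi, yi)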
######## Main ########
n_dim = 3
n_forget = 1_000
n_samples = 100_000 + n_forget
n_chains = 100
rng_key = jax.random.PRNGKey(42)
# generation of (xi,yi) set
sample_size = 5_000
sigma_e = 1.5 # true value of parameter error sigma
random_num_generator = np.random.RandomState(0)
xi = 10.0 * random_num_generator.rand(sample_size)
e = random_num_generator.normal(0, sigma_e, sample_size)
yi = 1.0 + 2.0 * xi + e # a = 1.0; b = 2.0; y = a + b*x
rng_keys = jax.random.split(rng_key, n_chains) # generate an array of size (n_chains, 2)
initial_position = jnp.ones((n_dim, n_chains)) # generate an array of size (n_dim, n_chains)
# so for vmap one should connect axis 0 of rng_keys
# and axis 1 of initial_position
#print("main initial_position shape",initial_position.shape)
run_mcmc = jax.vmap(jax_metropolis_sampler,
                    in_axes=(0, None, None, 1),  # see comment above
                    out_axes=0)  # output axis 0 holds the vectorization over n_chains
                                 # => (n_chains, n_samples, n_dim)
all_positions = run_mcmc(rng_keys, n_samples,
                         lambda par: jax_my_logpdf(par, xi, yi),
                         initial_position)
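Just for context, the sampler output can then be post-processed along these lines (a sketch only: it drops the n_forget burn-in samples and is not part of the timed call):
chains = all_positions[:, n_forget:, :]        # (n_chains, n_samples - n_forget, n_dim)
print(chains.shape, chains.mean(axis=(0, 1)))  # rough posterior means for (a, b, sigma)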
Then, once the code has been called once (so compilation has already happened), I time it with:
%timeit all_positions = run_mcmc(rng_keys, n_samples, lambda par: jax_my_logpdf(par, xi, yi), initial_position)
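In case it matters for the comparison, my understanding is that a JAX benchmark is usually written with an explicit synchronisation on the result, roughly like this sketch, so that the asynchronous dispatch is not what gets measured:
%timeit run_mcmc(rng_keys, n_samples, lambda par: jax_my_logpdf(par, xi, yi), initial_position).block_until_ready()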
On CPU without JIT (i.e. with the @partial lines left commented out) I get 1 min 27 s, and with JIT I get 1 min 20 s (both figures are means over 7 runs). Thanks for your advice.