
Here is a simple piece of JAX code that shows the Metropolis algorithm in action on a 3-parameter Bayesian regression problem. It runs fine on CPU even without JIT compilation. What I would like to understand is why, when the two JIT-related lines are uncommented, the timings are not really different, whether I compare CPU with and without JIT, or CPU against a K80 Nvidia GPU.

Am I perhaps coding this in a wrong/inefficient way?

import jax
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import jacfwd, jacrev, hessian
from functools import partial

import scipy.stats as scs
import numpy as np

#@partial(jax.jit, static_argnums=(1,))
def jax_metropolis_kernel(rng_key, logpdf, position, log_prob):
    """Moves the chain by one step using the Random Walk Metropolis algorithm."""
    key, subkey = jax.random.split(rng_key)
    move_proposals = jax.random.normal(key, shape=position.shape) * 0.1

    proposal = position + move_proposals
    proposal_log_prob = logpdf(proposal)

    log_uniform = jnp.log(jax.random.uniform(subkey))
    do_accept = log_uniform < proposal_log_prob - log_prob

    position = jnp.where(do_accept, proposal, position)
    log_prob = jnp.where(do_accept, proposal_log_prob, log_prob)
    return position, log_prob

#@partial(jax.jit, static_argnums=(1, 2))
def jax_metropolis_sampler(rng_key, n_samples, logpdf, initial_position):
    """Generate samples using the Random Walk Metropolis algorithm.
    """
    def mh_update(i, state):
        key, positions, log_prob = state
        _, key = jax.random.split(key)
                
        new_position, new_log_prob = jax_metropolis_kernel(key, 
                                                           logpdf, 
                                                           positions[i-1], 
                                                           log_prob)
                
        
        positions=positions.at[i].set(new_position)
        return (key, positions, new_log_prob)

    #Initialisation
    keys = jax.random.split(rng_key,num=4)
    all_positions = jnp.zeros((n_samples,initial_position.shape[0]))  # 1 chain for each vmap call    ?
#    all_positions=all_positions.at[0,0].set(scs.norm.rvs(loc=1,scale=1))
#    all_positions=all_positions.at[0,1].set(scs.norm.rvs(loc=2,scale=1))
#    all_positions=all_positions.at[0,2].set(scs.uniform.rvs(loc=1,scale=2))
    
    all_positions=all_positions.at[0,0].set(jax.random.normal(keys[0])+1.)
    all_positions=all_positions.at[0,1].set(jax.random.normal(keys[1])+2.)
    all_positions=all_positions.at[0,2].set(jax.random.uniform(keys[2],minval=1.0, maxval=3.0))

    logp = logpdf(all_positions[0])
    
    initial_state = (rng_key,all_positions, logp)
    rng_key, all_positions, log_prob = jax.lax.fori_loop(1, n_samples, 
                                                 mh_update, 
                                                 initial_state)
    
    return all_positions

def jax_my_logpdf(par,xi,yi):
    # priors: a=par[0], b=par[1], sigma=par[2]
    logpdf_a = jax.scipy.stats.norm.logpdf(x=par[0],loc=1.,scale=1.)
    logpdf_b = jax.scipy.stats.norm.logpdf(x=par[1],loc=2.,scale=1.)
    logpdf_s = jax.scipy.stats.gamma.logpdf(x=par[2],a=3,scale=1.)

    val = xi*par[1]+par[0]
    tmp = jax.scipy.stats.norm.logpdf(x=val,loc=yi,scale=par[2])    
    log_likeh= jnp.sum(tmp)
    
    return log_likeh + logpdf_a + logpdf_b + logpdf_s

######## Main ########
n_dim = 3
n_forget = 1_000
n_samples = 100_000 + n_forget
n_chains = 100
rng_key = jax.random.PRNGKey(42)

# generation of (xi,yi) set
sample_size = 5_000
sigma_e = 1.5             # true value of parameter error sigma
random_num_generator = np.random.RandomState(0)
xi = 10.0 * random_num_generator.rand(sample_size)
e = random_num_generator.normal(0, sigma_e, sample_size)
yi = 1.0 + 2.0 * xi +  e          # a = 1.0; b = 2.0; y = a + b*x


rng_keys = jax.random.split(rng_key, n_chains)    # generate an array of size (n_chains, 2)
initial_position = jnp.ones((n_dim, n_chains))    # generate an array of size (n_dim, n_chains)
                                                  # so for vmap one should connect axis 0 of rng_keys  
                                                  # and axis 1 of initial_position

#print("main initial_position shape",initial_position.shape)

run_mcmc = jax.vmap(jax_metropolis_sampler, 
                    in_axes=(0, None, None, 1),   # see comment above 
                    out_axes=0)                   # output axis 0 hold the vectorization over n_chains
                                                  # => (n_chains, n_samples, n_dims)


all_positions = run_mcmc(rng_keys, n_samples, 
                     lambda par: jax_my_logpdf(par,xi,yi), 
                     initial_position)

Then, once the code has been called once, I time it with:

%timeit all_positions = run_mcmc(rng_keys, n_samples, 
                     lambda par: jax_my_logpdf(par,xi,yi), 
                     initial_position)

For the CPU time without JIT (i.e. with the @partial lines commented out) I get 1 min 27 s, and with JIT I get 1 min 20 s (both results are averages over 7 runs). Thanks for your advice.


1 Answer


The reason JIT compilation gives you no speedup here is that most of your computation happens inside the function you pass to fori_loop, which is compiled by default, so in relative terms there is not much left to gain from JIT-compiling the remaining steps.
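A minimal sketch of that point, using the question's own variable names (the `run_mcmc_jit` wrapper and the reuse of a single `logpdf` object are illustrative assumptions, not part of the original code): JIT-compiling the outer call only removes the thin Python layer around `jax.lax.fori_loop`, whose body XLA already compiles, so the timings should stay close to the un-jitted case.

import jax

# Hypothetical wrapper around the question's vmapped sampler.
# n_samples and the log-density callable cannot be traced, so they are static.
run_mcmc_jit = jax.jit(run_mcmc, static_argnums=(1, 2))

# Keep one callable object so the jit cache is actually reused between calls;
# re-creating the lambda on every call would force a re-trace.
logpdf = lambda par: jax_my_logpdf(par, xi, yi)

# The first call pays the tracing/compilation cost; later calls reuse the
# compiled kernel, but the fori_loop body was already compiled either way.
samples = run_mcmc_jit(rng_keys, n_samples, logpdf, initial_position)
samples.block_until_ready()  # wait for the asynchronous computation to finish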

As for why your computation takes several minutes to execute: you are running fori_loop for 101,000 steps, and each step does a fair amount of work. What you are seeing is simply how long it takes to run the code for the inputs you specified.
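To see the scaling directly, here is a small timing sketch (again reusing the question's variables; `time.perf_counter` and the chosen step counts are just illustrative): because JAX dispatches work asynchronously, it is safest to call `block_until_ready()` before reading the clock, and the run time should grow roughly linearly with the number of fori_loop steps.

import time

logpdf = lambda par: jax_my_logpdf(par, xi, yi)

for n in (1_000, 10_000, 101_000):
    t0 = time.perf_counter()
    # block_until_ready() makes the measurement include the device computation,
    # not just the asynchronous dispatch of the work.
    run_mcmc(rng_keys, n, logpdf, initial_position).block_until_ready()
    print(f"n_samples = {n:>7,d}: {time.perf_counter() - t0:.2f} s")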

Answered on 2021-09-23.