我有两种对jnp = jax.numpy
. 一个直截了当的:
jnp.exp(-X/reg)
还有一些额外的动作:
def exp_reg(X, reg):
K = jnp.empty_like(X)
K = jnp.divide(X, -reg)
return jnp.exp(K)
但是,当我测试它们时:
%timeit jnp.exp(-X/reg).block_until_ready()
%timeit exp_reg(X, reg).block_until_ready()
尽管从表面上看有一些额外的开销,但第二种方法表现得更好。我运行了%timeit
一个大小为 2000 x 2000 的矩阵:
7.85 ms ± 567 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
5.19 ms ± 52.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
为什么会这样?