1

我有两种对jnp = jax.numpy. 一个直截了当的:

jnp.exp(-X/reg)

还有一些额外的动作:

def exp_reg(X, reg):
    K = jnp.empty_like(X)
    K = jnp.divide(X, -reg)
    return jnp.exp(K)

但是,当我测试它们时:

%timeit jnp.exp(-X/reg).block_until_ready()
%timeit exp_reg(X, reg).block_until_ready()

尽管从表面上看有一些额外的开销,但第二种方法表现得更好。我运行了%timeit一个大小为 2000 x 2000 的矩阵:

7.85 ms ± 567 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
5.19 ms ± 52.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

为什么会这样?

4

1 回答 1

1

这里的区别在于操作的顺序。

jnp.exp(-X/reg)中,您正在否定 的每个条目X,然后将结果的每个条目除以reg。这是对数组的两次传递X

exp_reg你否定reg(这可能是一个标量值?)然后除以X结果。那是一次通过X

如果X很大,我希望第一种方法比第二种方法稍慢,因为要多次通过X

幸运的是,由于您使用的是 JAX,因此您可以jit编译您的代码,在这种情况下,XLA 通常可以优化这些等价的操作顺序。实际上,对于您的两个函数,编译消除了差异:

from jax import jit
import jax.numpy as jnp
import numpy as np

def exp_reg1(X, reg):
  return jnp.exp(-X/reg)

def exp_reg2(X, reg):
  K = jnp.divide(X, -reg)
  return jnp.exp(K)

X = jnp.array(np.random.rand(1000, 1000))
reg = 2.0

%timeit exp_reg1(X, reg)
# 100 loops, best of 3: 3.17 ms per loop
%timeit exp_reg2(X, reg)
# 100 loops, best of 3: 2.2 ms per loop

# Trigger compilation
jit(exp_reg1)(X, reg)
jit(exp_reg2)(X, reg)

%timeit jit(exp_reg1)(X, reg)
# 1000 loops, best of 3: 1.92 ms per loop
%timeit jit(exp_reg2)(X, reg)
# 100 loops, best of 3: 1.84 ms per loop

(旁注:没有理由K在将操作结果分配给同名变量之前预先分配一个空数组)。

于 2020-11-04T18:26:08.270 回答