jax.vmap
并且jax.numpy.vectorize
具有完全不同的语义,并且仅在您的示例中的单个一维输入的情况下恰好相似。
的目的jax.vmap
是将函数映射到沿单个显式轴的一个或多个输入上,由in_axes
参数指定。另一方面,根据 numpy 广播规则,沿零个或多个隐式轴jax.numpy.vectorize
将函数映射到一个或多个输入。
要查看差异,让我们传递两个二维输入并在函数中打印形状:
import jax
import jax.numpy as jnp
def print_shape(x, y):
print(f"x.shape = {x.shape}")
print(f"y.shape = {y.shape}")
return x + y
x = jnp.zeros((20, 10))
y = jnp.zeros((20, 10))
_ = jax.vmap(print_shape)(x, y)
# x.shape = (10,)
# y.shape = (10,)
_ = jnp.vectorize(print_shape)(x, y)
# x.shape = ()
# y.shape = ()
请注意,vmap
仅沿第一个轴vectorize
映射,而沿两个输入轴映射。
还要注意,隐式映射vectorize
意味着它可以更灵活地使用;例如:
x2 = jnp.arange(10)
y2 = jnp.arange(20).reshape(20, 1)
def add(x, y):
# vectorize always maps over all axes, such that the function is applied elementwise
assert x.shape == y.shape == ()
return x + y
jnp.vectorize(add)(x2, y2).shape
# (20, 10)
vectorize
将根据 numpy 广播规则遍历输入的所有轴。另一方面,vmap
默认情况下无法处理:
jax.vmap(add)(x2, y2)
# ValueError: vmap got inconsistent sizes for array axes to be mapped:
# arg 0 has shape (10,) and axis 0 is to be mapped
# arg 1 has shape (20, 1) and axis 0 is to be mapped
# so
# arg 0 has an axis to be mapped of size 10
# arg 1 has an axis to be mapped of size 20
完成同样的操作vmap
需要更多的思考,因为有两个独立的映射轴,并且一些轴是广播的。但是你可以通过这种方式完成同样的事情:
jax.vmap(jax.vmap(add, in_axes=(None, 0)), in_axes=(0, None))(x2, y2[:, 0]).shape
# (20, 10)
后者嵌套vmap
本质上是您使用jax.numpy.vectorize
.
至于在任何给定情况下使用哪个:
- 如果您想在单个明确指定的输入轴上映射函数,请使用
jax.vmap
- 如果您希望根据应用于输入的 numpy 广播规则将函数的输入映射到零个或多个轴上,请使用
jax.numpy.vectorize
.
- 在变换相同的情况下(例如在映射一维输入的函数时)倾向于使用
vmap
,因为它更直接地做你想做的事情。