3

jax.numpy.vectorize和 和有什么区别jax.vmap?这是一个小片段

import jax
import jax.numpy as jnp

def f(x):
     return jnp.exp(-x)*jnp.sin(x)

gf = jax.grad(f)
x = jnp.arange(0,1,0.1)

jax.vmap(gf)(x)
jnp.vectorize(gf)(x)

两种计算都给出相同的结果:

DeviceArray([ 1. , 0.80998397, 0.63975394, 0.4888039 , 0.35637075, 0.24149445, 0.14307144, 0.05990037, -0.00927836, -0.06574923]), dtype=float32

如何决定使用哪一个,以及在性能方面是否存在差异?

4

1 回答 1

4

jax.vmap并且jax.numpy.vectorize具有完全不同的语义,并且仅在您的示例中的单个一维输入的情况下恰好相似。

的目的jax.vmap是将函数映射到沿单个显式轴的一个或多个输入上,由in_axes参数指定。另一方面,根据 numpy 广播规则,沿零个或多个隐式轴jax.numpy.vectorize将函数映射到一个或多个输入。

要查看差异,让我们传递两个二维输入并在函数中打印形状:

import jax
import jax.numpy as jnp

def print_shape(x, y):
  print(f"x.shape = {x.shape}")
  print(f"y.shape = {y.shape}")
  return x + y

x = jnp.zeros((20, 10))
y = jnp.zeros((20, 10))

_ = jax.vmap(print_shape)(x, y)
# x.shape = (10,)
# y.shape = (10,)

_ = jnp.vectorize(print_shape)(x, y)
# x.shape = ()
# y.shape = ()

请注意,vmap仅沿第一个轴vectorize映射,而沿两个输入轴映射。

还要注意,隐式映射vectorize意味着它可以更灵活地使用;例如:

x2 = jnp.arange(10)
y2 = jnp.arange(20).reshape(20, 1)

def add(x, y):
  # vectorize always maps over all axes, such that the function is applied elementwise
  assert x.shape == y.shape == ()
  return x + y

jnp.vectorize(add)(x2, y2).shape
# (20, 10)

vectorize将根据 numpy 广播规则遍历输入的所有轴。另一方面,vmap默认情况下无法处理:

jax.vmap(add)(x2, y2)
# ValueError: vmap got inconsistent sizes for array axes to be mapped:
# arg 0 has shape (10,) and axis 0 is to be mapped
# arg 1 has shape (20, 1) and axis 0 is to be mapped
# so
# arg 0 has an axis to be mapped of size 10
# arg 1 has an axis to be mapped of size 20

完成同样的操作vmap需要更多的思考,因为有两个独立的映射轴,并且一些轴是广播的。但是你可以通过这种方式完成同样的事情:

jax.vmap(jax.vmap(add, in_axes=(None, 0)), in_axes=(0, None))(x2, y2[:, 0]).shape
# (20, 10)

后者嵌套vmap本质上是您使用jax.numpy.vectorize.

至于在任何给定情况下使用哪个:

  • 如果您想在单个明确指定的输入轴上映射函数,请使用jax.vmap
  • 如果您希望根据应用于输入的 numpy 广播规则将函数的输入映射到零个或多个轴上,请使用jax.numpy.vectorize.
  • 在变换相同的情况下(例如在映射一维输入的函数时)倾向于使用vmap,因为它更直接地做你想做的事情。
于 2021-09-09T13:11:57.150 回答