2

我试图了解 JAX vmap 的行为,所以我编写了以下代码:

import jax.numpy as jnp
from jax import vmap

def what(a,b,c):
  z = jnp.dot(a,b)
  return z + c

v_what = vmap(what, in_axes=(None,0,None))

a = jnp.array([1,1,3])
b = jnp.array([2,2])
c = 1.0

v_what(a,b,c)

输出是:

DeviceArray([[3., 3., 7.],
             [3., 3., 7.]], dtype=float32)

我知道唯一被改变的输入是b,但是有人可以解释为什么会这样吗?在我对函数进行矢量化后,点积的行为如何?

4

1 回答 1

2

您已指定转换后的函数应映射到 的第一个轴上b,而不是映射到a或的任何轴上c。粗略地说,您已经创建了一个映射函数来执行此操作:

def v_what(a, b, c):
  return jnp.stack([what(a, b_i, c) for b_i in b], axis=0)

对于您的输入,每行中的点积看起来像jnp.dot(a, 2),结果相当于a * 2

于 2021-03-09T14:39:44.920 回答