我试图了解 JAX vmap 的行为,所以我编写了以下代码:
import jax.numpy as jnp
from jax import vmap
def what(a,b,c):
z = jnp.dot(a,b)
return z + c
v_what = vmap(what, in_axes=(None,0,None))
a = jnp.array([1,1,3])
b = jnp.array([2,2])
c = 1.0
v_what(a,b,c)
输出是:
DeviceArray([[3., 3., 7.],
[3., 3., 7.]], dtype=float32)
我知道唯一被改变的输入是b
,但是有人可以解释为什么会这样吗?在我对函数进行矢量化后,点积的行为如何?