我正在学习使用 JAX,但我对使用 JAX 有一些疑问,jit并且vmap无法通过阅读文档来解决。
分别对
jit几个函数和jit使用它们的函数有影响吗?例如,如果我有函数foo()和bar()函数@jax.jit def fooBar(x): return foo(x) + bar(x)如果
foo()和bar()已经被 jitted 有什么区别吗?我应该在我
jit之后执行一个函数vmap吗?在上面的例子中,我应该做jax.jit(jax.vmal(fooBar))还是只做jax.vmap(fooBar)?