我正在学习使用 JAX,但我对使用 JAX 有一些疑问,jit
并且vmap
无法通过阅读文档来解决。
分别对
jit
几个函数和jit
使用它们的函数有影响吗?例如,如果我有函数foo()
和bar()
函数@jax.jit def fooBar(x): return foo(x) + bar(x)
如果
foo()
和bar()
已经被 jitted 有什么区别吗?我应该在我
jit
之后执行一个函数vmap
吗?在上面的例子中,我应该做jax.jit(jax.vmal(fooBar))
还是只做jax.vmap(fooBar)
?