1

我正在学习使用 JAX,但我对使用 JAX 有一些疑问,jit并且vmap无法通过阅读文档来解决。

  1. 分别对jit几个函数和jit使用它们的函数有影响吗?例如,如果我有函数foo()bar()函数

    @jax.jit 
    def fooBar(x):
        return foo(x) + bar(x)
    

    如果foo()bar()已经被 jitted 有什么区别吗?

  2. 我应该在我jit之后执行一个函数vmap吗?在上面的例子中,我应该做jax.jit(jax.vmal(fooBar))还是只做jax.vmap(fooBar)

4

1 回答 1

2

When it comes to performance of code execution, there is no difference between jitting functions separately and jitting once at the outer function (functionally there is one subtle difference: jit-compiling the inner function will wrap the contents in an xla_call primitive, but this makes little to no difference for the final compilation & execution).

When using vmap on the other hand, there is no implicit compilation. vmap(f) will be executed in eager mode, while jit(vmap(f)) will be just-in-time compiled and generally result in faster execution.

于 2021-06-25T15:19:47.890 回答