0

提前为这个问题的含糊程度表示歉意(不幸的是,我对 jax 跟踪的工作原理知之甚少,无法更准确地表述它),但是:有没有办法将函数或代码块与 jax 跟踪完全隔离?

对于上下文,我具有以下形式的功能:

def f(x, y):
   z = h(y)
   return g(x, z)

本质上,我想在进行任何 jax 转换时调用g(x, z), 并将其视为常量。z但是,设置参数z非常尴尬,因此使用辅助函数h将更易于指定的输入y转换为g. 我希望 jax 将h其视为不可追踪的黑匣子,因此jit(lambda x: f(x, y0))对特定y0对象的操作与首先使用 计算z0 = h(y0)numpy然后执行jit(lambda x: g(x, z0))(以及与grad或任何其他函数转换类似)相同。

在我的代码中,我已经编写h了只使用标准numpy(我认为这可能会导致黑盒行为),但是 的编译时间jit(lambda x: f(x, y0))明显长于jit(lambda x: g(x, z0))for的编译时间z0 = h(y0)。我有一种感觉,编译时间可能与 jax 跟踪中的许多循环有关h,但我不确定。

一些附加说明:

  • 以一种对 jax 友好的方式编写h会很尴尬(输入格式参差不齐,大量循环/条件,输出形状取决于输入值等)并且最终比它的价值更麻烦,因为该函数执行起来非常便宜,我不知道永远不需要区分它(输入数据是基于整数的)。

想法?

为清楚起见编辑添加:我知道如果例如f是顶级功能,则可能有解决方法。在这种情况下,让用户首先调用h以“预编译”对 jax 友好的输入g,然后自由地执行他们想要的任何 jax 转换,这并不是什么大问题lambda x: g(x, z0)。但是,我在想象这样的情况,我们有许多想要链接在一起的函数,它们具有相同的结构f,其中有一些对 jax 不友好的输入/计算,但这些输入将始终被视为 jax计算的一部分。原则上,人们总是可以提取这些预先计算来设置 jax 的东西,但是如果我们有一个这种类型的函数的非平凡集合,它们会相互调用,这似乎很困难。

是否有某种方法可以控制如何f跟踪,以便在跟踪时知道只评估z=h(y)(而不是跟踪h)然后继续跟踪g(x, z)

4

1 回答 1

0
f_jitted = jax.jit(f, static_argnums=1)

static_argnums 参数可能会有所帮助

https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html

您可以使用诸如static_argnumsfor之类的转换参数jit来避免跟踪转换函数的特定参数,尽管代价是更多的重新编译。

于 2021-03-01T09:39:30.330 回答