提前为这个问题的含糊程度表示歉意(不幸的是,我对 jax 跟踪的工作原理知之甚少,无法更准确地表述它),但是:有没有办法将函数或代码块与 jax 跟踪完全隔离?
对于上下文,我具有以下形式的功能:
def f(x, y):
z = h(y)
return g(x, z)
本质上,我想在进行任何 jax 转换时调用g(x, z)
, 并将其视为常量。z
但是,设置参数z
非常尴尬,因此使用辅助函数h
将更易于指定的输入y
转换为g
. 我希望 jax 将h
其视为不可追踪的黑匣子,因此jit(lambda x: f(x, y0))
对特定y0
对象的操作与首先使用 计算z0 = h(y0)
,numpy
然后执行jit(lambda x: g(x, z0))
(以及与grad
或任何其他函数转换类似)相同。
在我的代码中,我已经编写h
了只使用标准numpy
(我认为这可能会导致黑盒行为),但是 的编译时间jit(lambda x: f(x, y0))
明显长于jit(lambda x: g(x, z0))
for的编译时间z0 = h(y0)
。我有一种感觉,编译时间可能与 jax 跟踪中的许多循环有关h
,但我不确定。
一些附加说明:
- 以一种对 jax 友好的方式编写
h
会很尴尬(输入格式参差不齐,大量循环/条件,输出形状取决于输入值等)并且最终比它的价值更麻烦,因为该函数执行起来非常便宜,我不知道永远不需要区分它(输入数据是基于整数的)。
想法?
为清楚起见编辑添加:我知道如果例如f
是顶级功能,则可能有解决方法。在这种情况下,让用户首先调用h
以“预编译”对 jax 友好的输入g
,然后自由地执行他们想要的任何 jax 转换,这并不是什么大问题lambda x: g(x, z0)
。但是,我在想象这样的情况,我们有许多想要链接在一起的函数,它们具有相同的结构f
,其中有一些对 jax 不友好的输入/计算,但这些输入将始终被视为 jax计算的一部分。原则上,人们总是可以提取这些预先计算来设置 jax 的东西,但是如果我们有一个这种类型的函数的非平凡集合,它们会相互调用,这似乎很困难。
是否有某种方法可以控制如何f
跟踪,以便在跟踪时知道只评估z=h(y)
(而不是跟踪h
)然后继续跟踪g(x, z)
?