1

我是新手jax。当我阅读文档时,我对jit.

缓存部分,它说“避免在循环内调用 jax.jit。这样做有效地在每次调用时创建一个新的 f ,每次都会编译它而不是重用相同的缓存函数”。但是,运行以下代码只会产生一种打印副作用:

import jax
def unjitted_loop_body(prev_i):
  print("tracing...")
  return prev_i + 1

def g_inner_jitted_poorly(x, n):
  i = 0
  while i < n:
    # Don't do this!
    i = jax.jit(unjitted_loop_body)(i)
  return x + i

g_inner_jitted_poorly(10, 20)
# output:
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
tracing...
Out[1]: DeviceArray(30, dtype=int32)

字符串“tracing...”只打印一次,似乎jit不再跟踪函数。

这是故意的吗?谢谢你的帮助!

4

0 回答 0