2

这是一个基本的例子。

@jax.jit
def block(arg1, arg2):
   for x1 in range(cons1):
       for x2 in range(cons2):
          for x3 in range(cons3):
             --do something--
   return result

当 cons 较小时,编译时间约为一分钟。缺点越大,编译时间就越长——10 分钟。我需要更高的缺点。可以做什么?从我正在阅读的内容来看,循环是原因。它们在编译时展开。有什么解决方法吗?还有 jax.fori_loop。但我不明白如何使用它。有 jax.experimental.loops 模块,但我还是无法理解它。

我对这一切都很陌生。因此,感谢所有帮助。如果您能提供一些如何使用 jax 循环的示例,那将不胜感激。

另外,什么是好的编译时间?几分钟内就可以了吗?在其中一个示例中,编译时间为 262 秒,剩余运行时间约为 0.1-0.2 秒。

运行时的任何收益都会被编译时间所掩盖。

4

2 回答 2

2

JAX 的 JIT 编译器将所有 Python 循环变平。要明白我的意思,请看一下这个简单的函数 run through jax.make_jaxpr,这是一种检查 JAX 的跟踪器如何解释 python 代码的方法(请参阅了解 Jaxprs了解更多信息):

import jax

def f(x):
  for i in range(5):
    x += i
  return x

print(jax.make_jaxpr(f)(0))
# { lambda  ; a.
#   let b = add a 0
#       c = add b 1
#       d = add c 2
#       e = add d 3
#       f = add e 4
#   in (f,) }

请注意,循环是扁平的:每一步都成为发送到 XLA 编译器的显式操作。XLA 编译时间会随着函数中操作数量的增加而增加,因此三重嵌套的 for 循环会导致编译时间过长是有道理的。

那么,如何解决这个问题呢?好吧,不幸的是,答案取决于你--do something--在做什么,所以我无法猜测。

一般来说,最好的选择是使用向量化数组操作,而不是循环这些向量中的值;例如,这是一种添加两个向量的非常慢的方法:

import jax.numpy as jnp

def f_slow(x, y):
  z = []
  for xi, yi in zip(xi, yi):
    z.append(xi + yi)
  return jnp.array(z)

这是一种更快的方法来做同样的事情:

def f_fast(x, y):
  return x + y

如果您的操作不适合矢量化,另一种选择是使用宽松的控制流运算符代替for循环:这会将循环向下推入 XLA。这在 CPU 上可以有相当好的性能,但与等效的向量化数组操作相比,在加速器上速度较慢。

有关 JAX 和 Python 控制流语句(例如forifwhile等)的更多讨论,请参阅JAX - The Sharp Bits : Control Flow

于 2021-09-06T20:33:26.340 回答
0

我不确定这是否与 with 相同numba,但这可能是类似的情况。

当我使用numba.jit编译器并且有大数据输入时,我首先在一些小示例数据上编译函数,然后使用它。

伪代码:

func_being_compiled(small_amount_of_data)  # compile-only purpose
func_being_compiled(large_amount_of_data)

于 2021-09-06T10:23:01.583 回答