这是一个基本的例子。
@jax.jit
def block(arg1, arg2):
for x1 in range(cons1):
for x2 in range(cons2):
for x3 in range(cons3):
--do something--
return result
当 cons 较小时,编译时间约为一分钟。缺点越大,编译时间就越长——10 分钟。我需要更高的缺点。可以做什么?从我正在阅读的内容来看,循环是原因。它们在编译时展开。有什么解决方法吗?还有 jax.fori_loop。但我不明白如何使用它。有 jax.experimental.loops 模块,但我还是无法理解它。
我对这一切都很陌生。因此,感谢所有帮助。如果您能提供一些如何使用 jax 循环的示例,那将不胜感激。
另外,什么是好的编译时间?几分钟内就可以了吗?在其中一个示例中,编译时间为 262 秒,剩余运行时间约为 0.1-0.2 秒。
运行时的任何收益都会被编译时间所掩盖。