1

我开始使用 Google JAX 和内置的 jit 和 grad 功能。这些方面在我的机器上运行良好,但是当我增加参数数量时,我收到以下通知:

********************************
Slow compile?  XLA was built without compiler optimizations, which can be slow.  Try rebuilding with -c opt.
Compiling module jit_obj_func__1.9055
********************************

我很想增加输入参数的数量,所以我想很快我将需要更快的编译时间,所以这个通知很吸引我......但我不明白如何实现它。

我一直在使用 conda 来安装 jax。基本上,我在终端中运行以下命令:

    ~$ conda create --name jax
    ~$ conda activate jax
    ~$ conda install -c conda-forge jax matplotlib cudatoolkit

我确定在 conda 中安装时必须有一种方法可以添加一些选项(例如,使用conda install jax=arguments但我在任何地方的文档中都找不到如何操作。堆栈溢出似乎也没有任何内容-搜索只发现以下内容: 使用 jax 时 XLA 的 jit 编译速度非常慢

任何建议将不胜感激!

4

0 回答 0