问题标签 [google-jax]

For questions regarding programming in ECMAScript (JavaScript/JS) and its various dialects/implementations (excluding ActionScript). Note JavaScript is NOT the same as Java! Please include all relevant tags on your question; e.g., [node.js], [jquery], [json], [reactjs], [angular], [ember.js], [vue.js], [typescript], [svelte], etc.

0 投票
0 回答
867 浏览

python - 是否可以在 Google 的 Jax 机器学习库中使用对象

我正在尝试使用 Google 的 Jax 机器学习库编写 DC Gan 网络。为此,我创建了对象作为鉴别器和生成器,但是,当我测试鉴别器时,我得到了错误:

我查看了 Jax github 页面上的示例,据我所见,那里的示例都没有使用对象,这使我假设可能无法在 Jax 中使用对象。但如果是这样的话,我真的不明白为什么不能使用对象,这会是将来实现的东西吗?我只是天真地忽略了一些东西吗?

这是我的鉴别器对象:

我在这里更新参数:

0 投票
0 回答
195 浏览

machine-learning - 在向后传递 Google-JAX 中保存梯度

我正在使用JAX来实现一个简单的神经网络 (NN),并且我想在 NN 运行后访问并保存反向传播的梯度以供进一步分析。我可以使用 python 调试器临时访问和查看渐变(只要我不使用 jit)。但我想保存整个训练过程中的所有梯度,并在训练完成后对其进行分析。我为此使用 id_tap 和全局变量想出了一个相当老套的解决方案(参见下面的代码)。但我想知道是否有更好的解决方案不违反 JAX 的功能原则。

非常感谢!

0 投票
1 回答
43 浏览

gpu - 分析 JAX 代码:什么是 redzone_checker,为什么要花这么多时间?

我找到了这篇文章,但仍不清楚redzone_checker内核在做什么以及为什么。具体来说,它是否应该占用我应用程序运行时间的 90% 以上?TensorBoard 报告说它占用了我的 JAX 代码的绝大部分运行时间,我想知道

  1. 实际上是这个内核花费了太多时间,还是这是使用 TensorBoard 分析 JAX 的副作用(即,输出在某种程度上具有误导性)?
  2. 有没有办法减少redzone_checker内核花费的时间?这甚至是个好主意吗?

提前感谢您的任何见解。

0 投票
1 回答
82 浏览

machine-learning - 如何在 google-jax 中使用 grad 卷积?

感谢您阅读我的问题!

我刚刚学习了 Jax 中的自定义 grad 函数,我发现 JAX 用于定义自定义函数的方法非常优雅。

不过有一件事让我很困扰。

我创建了一个包装器,使松散卷积看起来像 PyTorch conv2d。

问题是我找不到使用它的 grad 函数的方法:

这就是我得到的。

请帮忙!

0 投票
0 回答
41 浏览

python - 使用 conda 进行 JAX/XLA 慢速编译

我开始使用 Google JAX 和内置的 jit 和 grad 功能。这些方面在我的机器上运行良好,但是当我增加参数数量时,我收到以下通知:

我很想增加输入参数的数量,所以我想很快我将需要更快的编译时间,所以这个通知很吸引我......但我不明白如何实现它。

我一直在使用 conda 来安装 jax。基本上,我在终端中运行以下命令:

我确定在 conda 中安装时必须有一种方法可以添加一些选项(例如,使用conda install jax=arguments但我在任何地方的文档中都找不到如何操作。堆栈溢出似乎也没有任何内容-搜索只发现以下内容: 使用 jax 时 XLA 的 jit 编译速度非常慢

任何建议将不胜感激!