问题标签 [jax]

For questions regarding programming in ECMAScript (JavaScript/JS) and its various dialects/implementations (excluding ActionScript). Note JavaScript is NOT the same as Java! Please include all relevant tags on your question; e.g., [node.js], [jquery], [json], [reactjs], [angular], [ember.js], [vue.js], [typescript], [svelte], etc.

0 投票
2 回答
87 浏览

python - 使用 jax.random.normal 对具有特定均值和标准差的单变量高斯分布进行采样

我正在尝试从具有特定标准偏差和均值的高斯中采样,我知道以下函数是从均值为零且标准偏差等于 1 的高斯中采样:

我可以通过这样做来调整平均值:x1 = x1 + mu,但是如何调整标准偏差?

0 投票
2 回答
416 浏览

python - Jax、jit 和动态形状:Tensorflow 的回归?

JAX的文档说,

并非所有 JAX 代码都可以 JIT 编译,因为它要求数组形状是静态的并且在编译时已知。

现在我有点惊讶,因为 tensorflow 具有类似tf.boolean_maskJAX 在编译时似乎无法执行的操作。

  1. 为什么 TensorFlow 会出现这样的回归?我假设底层 XLA 表示在两个框架之间共享,但我可能弄错了。我不记得 Tensorflow 曾经在动态形状方面遇到过问题,而且诸如此类的功能tf.boolean_mask已经存在了很久。
  2. 我们可以期待这种差距在未来缩小吗?如果不是,为什么在 JAX 的 jit 中无法实现 Tensorflow(以及其他)所支持的功能?

编辑

梯度通过tf.boolean_mask(显然不在掩码值上,它们是离散的);此处使用值未知的 TF1 样式图为例,因此 TF 不能依赖它们:

0 投票
0 回答
57 浏览

python - 如何在 JAX 中从数据集中读取数据

我是 JAX 的新手。我在下面有这段代码,其中“特征矩阵”作为数组,“目标向量”作为数组。但我不希望程序读取这些数据数组。这些数组已经存在于代码中。我想修改代码,以便它读取我已导入的波士顿房价数据集。有人可以告诉我需要对此代码进行哪些更改,以便使线性回归起作用吗?

0 投票
2 回答
655 浏览

python - 使用 Jax 的偏导数?

我对 Jax 文档感到困惑,这就是我想要做的:

和错误:

我参考官方教程代码:

结果:

我在这里做错了什么?我收集key到正在以某种重要的方式使用,但我无法弄清楚为什么/如何它是必要的。要回答这个问题,请根据需要调整第一个块中的代码以消除错误。

0 投票
1 回答
45 浏览

python - 两个矩阵行的所有成对叉积

我想有效地计算大小为 nx3 和 mx3 的两个矩阵 A 和 B 的行的所有成对叉积。并且理想情况下希望以 einsum 表示法实现这一点。

即输出矩阵C,将是(n X mx 3),

在哪里

C[0][0] = 交叉(n[0],m[0])

C[0][1] = 交叉(n[0],m[1])

...

C[1][0] = 交叉(n[1],m[0])

...

由于我采用的方法,使用 for 循环不是一种选择。

任何帮助将非常感激。

0 投票
1 回答
346 浏览

python - 为什么当我的函数 jax.grad 中有 np.power 时不能给我派生词?

我想训练一个简单的线性模型。x 和 y 下面的这些是我的数据。

f 是计算所有数据的均方误差的函数。

这给了我一个错误:

当我清除np.power代码时。为什么?

0 投票
1 回答
122 浏览

python - 与此 Python 函数等效的 JaxNumpy 兼容是什么?

如何以与 JAX 兼容的方式(例如,使用jax.numpy)实现以下内容?

0 投票
1 回答
2031 浏览

deep-learning - flax (google) 和 dm-haiku (deepmind) 之间的主要区别是什么?

亚麻dm-haiku之间的主要区别是什么?

从他们的描述中:

  • Flax,一个用于 JAX 的神经网络库
  • Haiku,受 Sonnet 启发的 JAX 神经网络库

问题

我应该选择哪个基于 jax 的库来实现,比如说DeepSpeech模型(由 CNN 层 + LSTM 层 + FC 组成)和 ctc-loss?


升级版

从 dm-haiku 的开发者那里找到了关于差异的解释:

Flax 包含更多的电池,并带有优化器、混合精度和一些训练循环(我听说这些是解耦的,你可以根据需要使用多少)。Haiku 旨在解决 NN 模块和状态管理,它将问题的其他部分留给其他库(例如用于优化的 optax)。

Haiku 被设计为将 Sonnet(一个 TF NN 库)移植到 JAX。因此,如果(像 DeepMind)你有大量的 Sonnet+TF 代码,你可能想在 JAX 中使用并且你希望迁移该代码(在任一方向)尽可能容易,那么 Haiku 是一个更好的选择。

我认为否则归结为个人喜好。在 Alphabet 中,每个图书馆都有 100 名研究人员使用,所以我认为无论哪种方式都不会出错。在 DeepMind,我们对 Haiku 进行了标准化,因为它对我们有意义。我建议查看两个库提供的示例代码,看看哪些符合您对结构化实验的偏好。我想你会发现,如果你以后改变主意,将代码从一个库移动到另一个库并不是很复杂。


原来的问题仍然相关。

0 投票
1 回答
147 浏览

python - 为什么我的神经网络使用 Jax 不收敛?

我正在学习 Jax,但遇到了一个奇怪的问题。如果我使用如下代码,

神经网络可以cos(x)很好地逼近函数。

但是如果我自己重写神经网络部分如下

我的神经网络将始终收敛到一个常数,这似乎被局部最小值所困。但是同样的神经网络和第一部分一样工作得很好。我真的很困惑。

唯一的区别应该是初始化、神经网络部分和参数的设置params。我尝试了不同的初始化,这没有区别。不知是不是因为优化的设置params不对,导致无法收敛。

0 投票
0 回答
48 浏览

tensorflow - 如何将 Tensorflow1 代码转换为 JAX 代码

我正在尝试将张量流代码转换为 JAX 代码。我的困难是 Stackoverflow 中几乎没有关于 JAX 的任何材料。以下是我要转换的代码,任何帮助将不胜感激。