问题标签 [jax]
For questions regarding programming in ECMAScript (JavaScript/JS) and its various dialects/implementations (excluding ActionScript). Note JavaScript is NOT the same as Java! Please include all relevant tags on your question; e.g., [node.js], [jquery], [json], [reactjs], [angular], [ember.js], [vue.js], [typescript], [svelte], etc.
python - 使用 jax.random.normal 对具有特定均值和标准差的单变量高斯分布进行采样
我正在尝试从具有特定标准偏差和均值的高斯中采样,我知道以下函数是从均值为零且标准偏差等于 1 的高斯中采样:
我可以通过这样做来调整平均值:x1 = x1 + mu
,但是如何调整标准偏差?
python - Jax、jit 和动态形状:Tensorflow 的回归?
JAX的文档说,
并非所有 JAX 代码都可以 JIT 编译,因为它要求数组形状是静态的并且在编译时已知。
现在我有点惊讶,因为 tensorflow 具有类似tf.boolean_mask
JAX 在编译时似乎无法执行的操作。
- 为什么 TensorFlow 会出现这样的回归?我假设底层 XLA 表示在两个框架之间共享,但我可能弄错了。我不记得 Tensorflow 曾经在动态形状方面遇到过问题,而且诸如此类的功能
tf.boolean_mask
已经存在了很久。 - 我们可以期待这种差距在未来缩小吗?如果不是,为什么在 JAX 的 jit 中无法实现 Tensorflow(以及其他)所支持的功能?
编辑
梯度通过tf.boolean_mask
(显然不在掩码值上,它们是离散的);此处使用值未知的 TF1 样式图为例,因此 TF 不能依赖它们:
python - 如何在 JAX 中从数据集中读取数据
我是 JAX 的新手。我在下面有这段代码,其中“特征矩阵”作为数组,“目标向量”作为数组。但我不希望程序读取这些数据数组。这些数组已经存在于代码中。我想修改代码,以便它读取我已导入的波士顿房价数据集。有人可以告诉我需要对此代码进行哪些更改,以便使线性回归起作用吗?
python - 使用 Jax 的偏导数?
我对 Jax 文档感到困惑,这就是我想要做的:
和错误:
我参考官方教程代码:
结果:
我在这里做错了什么?我收集key
到正在以某种重要的方式使用,但我无法弄清楚为什么/如何它是必要的。要回答这个问题,请根据需要调整第一个块中的代码以消除错误。
python - 两个矩阵行的所有成对叉积
我想有效地计算大小为 nx3 和 mx3 的两个矩阵 A 和 B 的行的所有成对叉积。并且理想情况下希望以 einsum 表示法实现这一点。
即输出矩阵C,将是(n X mx 3),
在哪里
C[0][0] = 交叉(n[0],m[0])
C[0][1] = 交叉(n[0],m[1])
...
C[1][0] = 交叉(n[1],m[0])
...
由于我采用的方法,使用 for 循环不是一种选择。
任何帮助将非常感激。
python - 为什么当我的函数 jax.grad 中有 np.power 时不能给我派生词?
我想训练一个简单的线性模型。x 和 y 下面的这些是我的数据。
f 是计算所有数据的均方误差的函数。
这给了我一个错误:
当我清除np.power
代码时。为什么?
python - 与此 Python 函数等效的 JaxNumpy 兼容是什么?
如何以与 JAX 兼容的方式(例如,使用jax.numpy
)实现以下内容?
deep-learning - flax (google) 和 dm-haiku (deepmind) 之间的主要区别是什么?
从他们的描述中:
- Flax,一个用于 JAX 的神经网络库
- Haiku,受 Sonnet 启发的 JAX 神经网络库
问题:
我应该选择哪个基于 jax 的库来实现,比如说DeepSpeech模型(由 CNN 层 + LSTM 层 + FC 组成)和 ctc-loss?
升级版。
从 dm-haiku 的开发者那里找到了关于差异的解释:
Flax 包含更多的电池,并带有优化器、混合精度和一些训练循环(我听说这些是解耦的,你可以根据需要使用多少)。Haiku 旨在解决 NN 模块和状态管理,它将问题的其他部分留给其他库(例如用于优化的 optax)。
Haiku 被设计为将 Sonnet(一个 TF NN 库)移植到 JAX。因此,如果(像 DeepMind)你有大量的 Sonnet+TF 代码,你可能想在 JAX 中使用并且你希望迁移该代码(在任一方向)尽可能容易,那么 Haiku 是一个更好的选择。
我认为否则归结为个人喜好。在 Alphabet 中,每个图书馆都有 100 名研究人员使用,所以我认为无论哪种方式都不会出错。在 DeepMind,我们对 Haiku 进行了标准化,因为它对我们有意义。我建议查看两个库提供的示例代码,看看哪些符合您对结构化实验的偏好。我想你会发现,如果你以后改变主意,将代码从一个库移动到另一个库并不是很复杂。
原来的问题仍然相关。
python - 为什么我的神经网络使用 Jax 不收敛?
我正在学习 Jax,但遇到了一个奇怪的问题。如果我使用如下代码,
神经网络可以cos(x)
很好地逼近函数。
但是如果我自己重写神经网络部分如下
我的神经网络将始终收敛到一个常数,这似乎被局部最小值所困。但是同样的神经网络和第一部分一样工作得很好。我真的很困惑。
唯一的区别应该是初始化、神经网络部分和参数的设置params
。我尝试了不同的初始化,这没有区别。不知是不是因为优化的设置params
不对,导致无法收敛。
tensorflow - 如何将 Tensorflow1 代码转换为 JAX 代码
我正在尝试将张量流代码转换为 JAX 代码。我的困难是 Stackoverflow 中几乎没有关于 JAX 的任何材料。以下是我要转换的代码,任何帮助将不胜感激。