问题标签 [jax]

For questions regarding programming in ECMAScript (JavaScript/JS) and its various dialects/implementations (excluding ActionScript). Note JavaScript is NOT the same as Java! Please include all relevant tags on your question; e.g., [node.js], [jquery], [json], [reactjs], [angular], [ember.js], [vue.js], [typescript], [svelte], etc.

0 投票
0 回答
100 浏览

python - 如何在 JAX 中创建非全连接卷积层?

我正在尝试在 JAX 中实现 LeNet5。我想我可以完成大部分工作,这是供参考的架构:

在此处输入图像描述

我想我知道如何实现前两层:

但是,我卡住的地方是 C3 层。问题是 C3 中的通道没有连接到S2 中的每个通道。根据下表,C3 中的 16 个通道中的每一个仅连接到 S2 中的通道子集:

我从(相当简陋的)JAX/STAX文档中看不到如何设置像这样的稀疏连接的卷积层。

0 投票
1 回答
38 浏览

jit - JAX jitting 函数会单独改变性能吗?

我正在学习使用 JAX,但我对使用 JAX 有一些疑问,jit并且vmap无法通过阅读文档来解决。

  1. 分别对jit几个函数和jit使用它们的函数有影响吗?例如,如果我有函数foo()bar()函数

    如果foo()bar()已经被 jitted 有什么区别吗?

  2. 我应该在我jit之后执行一个函数vmap吗?在上面的例子中,我应该做jax.jit(jax.vmal(fooBar))还是只做jax.vmap(fooBar)

0 投票
1 回答
264 浏览

python - 如何从 Jax 中的损失函数返回值字典?

假设您有一个损失函数,并且您想在训练时跟踪损失的各个子组件。这样做的最“jax”方式是什么?

您是否想只使用grad一个函数来拉出损失,然后再次重新计算值?有没有办法value_and_grad更有效地做到这一点?

0 投票
1 回答
690 浏览

python - 从头开始实现二元交叉熵——训练神经网络的结果不一致

我正在尝试使用 JAX 库及其小神经网络子模块“Stax”来实现和训练神经网络。由于这个库没有实现二进制交叉熵,我自己写了:

我实现了一个简单的神经网络并在 MNIST 上对其进行了训练,并开始怀疑我得到的一些结果。所以我在 Keras 中实现了相同的设置,我立即得到了截然不同的结果!相同的模型,在相同的数据上以相同的方式训练,在 Keras 中获得了 90% 的训练准确率,而不是在 JAX 中大约 50%。最终,我将问题的一部分归结为我对交叉熵的幼稚实现,据说它在数值上是不稳定的。在这篇文章和我找到的这段代码之后,我编写了以下新版本:

这工作得更好一些。现在我的 JAX 实现获得了高达 80% 的训练准确率,但这仍然比 Keras 获得的 90% 低很多。我想知道发生了什么?为什么我的两个实现的行为方式不同?

下面,我将我的两个实现浓缩为一个脚本。在这个脚本中,我在 JAX 和 Keras 中实现了相同的模型。我使用相同的权重初始化两者,并使用全批梯度下降对来自 MNIST 的 1000 个数据点进行 10 步训练,每个模型的数据相同。JAX 以 80% 的训练准确率结束,而 Keras 以 90% 结束。具体来说,我得到这个输出:

实际上,当我稍微改变条件(使用不同的随机初始权重或不同的训练集)时,有时我会得到 50% 的 JAX 准确度和 90% 的 Keras 准确度。

我最后交换了权重,以验证从训练中获得的权重确实是问题所在,与网络预测的实际计算或我计算准确性的方式无关。

编码:

尝试将第 57 行的 PRNG 种子更改为其他值,而不是0使用不同的初始权重运行实验。

0 投票
2 回答
95 浏览

tensorflow - 如何只安装 XLA?

我想使用 XLA 作为我项目的后端。有没有推荐的方法来独立安装它(没有 TensorFlow 的其余部分)。Jax 可能会这样做,但在他们的存储库中查看它并不明显。

更新我为此向TensorFlow提出了一张票

0 投票
1 回答
270 浏览

python - Jax - sigmoid 的 autograd 总是返回 nan

我试图区分一个函数,该函数在给定偏移平均值的情况下近似包含在 2 个限制(截断的高斯)内的高斯分数。 jnp.grad不允许我区分添加布尔过滤器(注释行),所以我不得不即兴使用 sigmoid。

但是,现在当截断边界很高时梯度总是 nan 我不明白为什么。

在下面的示例中,我正在计算具有 0 均值和 std=1 的高斯梯度,然后我用x.

如果我减小边界,则函数的行为符合预期。但这不是解决方案。当边界很高时,始终belows 变为 1。但如果是这种情况并且x对下面没有影响,那么它对梯度的贡献应该是0而不是nan。但如果我返回belows[0][0]而不是返回jnp.mean(filt, axis=0),我仍然得到nan

有任何想法吗?提前致谢(github上也有一个问题)

在此处输入图像描述

0 投票
1 回答
247 浏览

python - 不同长度的 JAX 批处理

我有一个函数在compute(x)哪里。现在,我想用它来把它转换成一个需要一批数组的函数,然后加快它的速度。是这样的:xjnp.ndarrayvmapx[i]jitcompute(x)

但是,每个数组x[i]都有不同的长度。我可以通过用尾随零填充数组来轻松解决这个问题,这样它们都具有相同的长度N并且vmap(compute)可以应用于具有 shape 的批次(batch_size, N)

但是,这样做会导致very_expensive_function()在每个数组的尾随零上也被调用x[i]。有没有办法修改compute()这样的,very_expensive_function()只在切片上调用x,而不干扰vmapand jit

0 投票
2 回答
275 浏览

python - 如何在 Jax fori_loop 机制中获取中间结果

我是 Jax 的新手,也不是 Python 专家。

我在我的 Mac 笔记本电脑上运行 jax 版本“0.2.14”。请在下面找到一个简单的代码,至少对我来说给出了一些结果。

但是,正如评论jax_metropolis_sampler方法中所述,我想保存中间结果“位置”,但我不知道正确地使用它jax_fori_loop,我想像我所做的那样做肯定是可怕的。

我很确定有人可以给我一个更好的利用 jax 并行性的解决方案。目前,我还没有研究 MixtureModel_jax 的前向/后向差异。

提前致谢

0 投票
2 回答
436 浏览

python-3.x - 无法在 Google TPU 中导入 python 包 jax

我正在使用 linux 控制台,输入 python 会将我带入 python 控制台。当我在 TPU 机器中使用以下命令时

然后它会生成以下 mss 并退出 python 提示符。

这个问题导致我的代码出现问题,所以我想弄清楚这个问题是什么以及如何解决这个问题?

0 投票
4 回答
1451 浏览

python - 在带有 m1 芯片的 mac 上导入 jax 失败

对于 python 3.8.8 并在 jupyter 笔记本和 python 终端中使用新的 mac air(带有 m1 芯片),import jax会引发此错误

我怀疑它是因为 m1 芯片而发生的。

我尝试使用 jax pip install jax,然后按照评论的建议,通过克隆他们的存储库并按照此处给出的说明从源代码构建它,但显示相同的错误消息。