问题标签 [jax]
For questions regarding programming in ECMAScript (JavaScript/JS) and its various dialects/implementations (excluding ActionScript). Note JavaScript is NOT the same as Java! Please include all relevant tags on your question; e.g., [node.js], [jquery], [json], [reactjs], [angular], [ember.js], [vue.js], [typescript], [svelte], etc.
python - 用于简单数组更新的 Jax vmap
我是 Jax 的新手,我正在努力转换其他人的代码,该代码使用了 numba “fastmath” 功能并依赖于许多嵌套的 for 循环而没有太多的性能损失。我正在尝试使用 Jax 的 vmap 函数重新创建相同的行为。但是,我目前在一些基本问题上遇到了很多困难。这是我尝试使用 vmap 进行矢量化的简化示例:
我将如何使用 vmap 重写这样的代码?虽然此代码相对容易手动矢量化,但我希望更好地了解 vmap 的工作原理,并希望任何答案都能帮助我。这些文档现在似乎并没有真正帮助我。我非常感谢您能提供的任何帮助。
jit - 是否可以 jit 使用 jax.numpy.unique 的函数?
以下代码不起作用:
错误消息涉及使用jnp.unique
:
关于尖位的文档解释说,如果内部数组的形状取决于参数值,则 jit 不起作用。这正是这里的情况。
根据文档,一个潜在的解决方法是指定静态参数。但这不适用于我的情况。几乎每个函数调用的参数都会改变。我已将我的代码拆分为一个预处理步骤,该步骤执行诸如 this之类jnp.unique
的计算,以及一个可以 jitted 的计算步骤。
但是我还是想问一下,是否有一些我不知道的解决方法?
python - 使用 jax 数组索引到 numpy 数组:错误的错误消息
以下 numpy 代码非常好:
迁移到 jax 后它也可以工作:
现在让我们尝试混合使用 numpy 和 jax:
这会产生以下错误:
如果不支持使用 jax 数组索引到 numpy 数组,那对我来说很好。但是错误信息似乎是错误的。事情变得更加混乱。如果稍微改变形状,代码就可以正常工作。在下面的示例中,我只编辑了从 (30,) 到 (40,) 的索引形状。没有更多错误消息:
我在 cpu 上运行 jax 版本“0.2.12”。这里发生了什么?
python - Jax 和训练神经网络
我是 JAX 的初学者,我正在尝试学习如何训练神经网络。我看过一些博客,但据我所知,没有一个可以轻松训练它的库,比如 sklearn 中的“fit”。我对分类任务感兴趣,你能推荐我任何博客,以便将他/她的算法应用到我的问题中吗?
python - 如何为 softmax 编写 JAX 自定义向量-雅可比积 (vjp)
为了理解 JAX 的反向模式自动差异,我尝试为 softmax 编写一个 custom_vjp,如下所示:
但是当我调用 jacrev 时,我收到关于 VJP 结果结构与 softmax 输入结构不匹配的错误:
但是,当我打印它们都具有形状 (3,) 但 JAX 似乎不同意的形状时,您可以看到?(实际上输入和输出是 3 x 3 矩阵,但这是因为 JAX 试图在 jacrev 中对 JVP 进行 vmap,因此一次性拉回 R(3) 的整个基础(即 3x3 单位矩阵)。
注意:如果我直接使用 jax.grad 或 jax.vjp,我会得到同样的错误。
python - 如何以复数计算 digamma 函数以便在 Tensorflow 中使用此函数(接受输入作为张量)?
我正在尝试使用 jax.scipy.special.digamma 以复数计算 digamma,但是,即使此包的文档说它可能很复杂,它仍然给我这个错误,这是文档所说的:
参数: z (array_like) – 实数或复数参数。
知道如何解决这个问题吗?或者是否有其他方法,例如其他库或其他包,允许我使用复数来计算 digamma 函数!?
python - 在 JAX 中计算词向量的移动平均值的最佳方法
假设我有一个W
形状矩阵,(n_words, model_dim)
其中n_words
是句子中model_dim
的单词数,是表示单词向量的空间的维度。计算这些向量的移动平均值的最快方法是什么?
例如,窗口大小为 2(窗口长度 = 5),我可以有这样的东西(这会引发错误TypeError: JAX 'Tracer' objects do not support item assignment
):
我想有一个更快的解决方案,jnp.convolve
但我不熟悉它。
python - a 和 b 的导数,使用算法微分
我的任务是使用 jax 为这个 函数找到 a 和 b 的导数
现在,我来这里的原因是因为我对 Python 的了解不够多,而对于所讨论的课程,我们也没有被认为是 Python。
任务是:
现在,我能够以正常方式做到这一点:
但我不熟悉算法微分,使用我们给出的例子,我试过这个:
但我收到此错误:
无法解压不可迭代的函数对象
如果有人可以帮助了解我如何做这样的事情,将不胜感激
tensorflow - 不同系数的“基于图”多项式评估的有效方法
我们正在尝试实现一个分段函数,基本上是大约 100 个具有不同系数的多项式,具体取决于 x 的值。
这将在 TensorFlow 或带有 JIT 的 jax 中实现,并针对数据数组进行优化。问题是实现这一目标的最佳方法可能是什么?
可以使用一百个 where,但这并不是最佳选择。或使用tf.switch_case
with tf.vectorize_map
(或类似)。
有什么想法吗?
python - 使用 JAX 进行梯度累积
我做了一个简单的脚本来尝试用 JAX 进行梯度累积。这个想法是让大批量大小(例如 64)分成适合 GPU 内存的小块(例如 4)。对于每个块,将存储在 pytree 中的结果梯度添加到当前批次梯度中。仅当计算大批量的所有块时才完成更新。在这个特定示例中,我们只是尝试将随机 512 维向量拟合到具有线性层的随机布尔值。这是脚本:
divide_pytree
我对and有疑问add_pytrees
。它是否真的修改了当前的批次梯度或者我错过了什么?此外,您是否看到此代码有任何速度问题?特别是,我应该使用jax.lax.fori_loop
in 代替传统的 python for 循环吗?
相关链接: