问题标签 [jax]

For questions regarding programming in ECMAScript (JavaScript/JS) and its various dialects/implementations (excluding ActionScript). Note JavaScript is NOT the same as Java! Please include all relevant tags on your question; e.g., [node.js], [jquery], [json], [reactjs], [angular], [ember.js], [vue.js], [typescript], [svelte], etc.

0 投票
0 回答
35 浏览

python - 简单的误解,神经切线入门

我复制并简化了这段极其简单的代码,只是为了开始使用神经切线和 jax

据我了解,这训练了一个“无限宽”的神经网络,以将一个训练示例实数拟合到所需的目标(也是单个实数)。

所以ypredictions应该是一样的。我正在训练一个例子,我有一个无限强大的模型,在我的脑海中这两个应该是相同的。他们不是。它打印:

更重要的是,如果我将训练示例的数量更改为 3,那么现在两个打印的尺寸不匹配!我希望找到两个向量,每个向量包含三个数字。我得到的是:

很明显,我有一个致命的误解。文档没有帮助我。任何人都可以阐明这个问题吗?

0 投票
1 回答
200 浏览

python - 选择 JAX 矩阵子集的最快方法是什么?

假设我有一个二维矩阵,我想在直方图中绘制它的值。为此,我需要执行以下操作:

然后使用列表绘制直方图。到目前为止一切顺利,只是原始矩阵中有我想排除的项目。为简单起见,假设我有一个这样的列表:

因此,list_1d应该有矩阵中的所有项目,而没有指向的exclude项目(的项目exclude是行和列索引)。

顺便说一句,这matrix_2d是一个 JAX 数组,这意味着它的内容在 GPU 中。

0 投票
0 回答
89 浏览

python - 在 gunicorn/flask 服务器中使用 google 的 JAX

我想提供一个应用程序,该应用程序使用烧瓶和 gunicorn 在 googles JAX 框架中处理数据。

如果在烧瓶内运行,一切正常。一旦我在 gunicorn 中运行应用程序,每个与 jax 相关的部分都会导致工作进程死亡,而不会引发任何异常。我尝试同时使用同步和 gthreads 作为工作线程,但结果相同。

我试图通过在 ThreadPoolExecutor 和 ProcessPoolExecutor 中包装相同的调用来查看 JAX 是否可以处理多处理和多线程,并且可以完美地工作。

在调试期间,每次我检查 JAX DeviceArray 时,应用程序都会崩溃。使用 JAX 跳过第一个计算也是如此。

任何帮助将非常感激!

0 投票
1 回答
91 浏览

python - 使用 numpy 和 jax 进行非传递子类化

我的问题很简单:

?

现在我会闲逛,所以 SE 会接受我的合理问题。

0 投票
1 回答
384 浏览

python - Jax 中的 vmap ops.index_update

我在下面有以下代码,它使用了一个简单的 for 循环。我只是想知道是否有办法 vmap 它?这是原始代码:

这是我使用 vmap 的尝试:

但我收到以下错误:

TypeError: vmap in_axes 必须是一个 int、None 或(嵌套)容器,这些类型作为叶子,但得到了 Traced<ShapedArray(float32[192,334])>with<DynamicJaxprTrace(level=0/1)>。

我有点困惑,因为范围是 int 类型,所以我不太确定发生了什么。

最后,我试图让这个小部分尽可能地优化,以获得最短的时间。

0 投票
1 回答
31 浏览

performance - 对矩阵元素求幂的两种方法的比较

我有两种对jnp = jax.numpy. 一个直截了当的:

还有一些额外的动作:

但是,当我测试它们时:

尽管从表面上看有一些额外的开销,但第二种方法表现得更好。我运行了%timeit一个大小为 2000 x 2000 的矩阵:

为什么会这样?

0 投票
2 回答
440 浏览

python - 切片 jax.numpy 数组时性能下降

在尝试对大型数组进行 SVD 压缩时,我在 Jax 中遇到了一些我不理解的行为。这是示例代码:

在考虑这段代码时,Jax/jit 比 SciPy 提供了巨大的性能提升,但最终我想减少 U 的维数,我通过将它包装在函数中来做到这一点:

这一步在计算时间方面的成本令人难以置信,比 SciPy 的等价物更昂贵,从这个比较中可以看出:

jax 和 scipy 的基准测试

sc_compress并且sc_process是上面 jax 代码的 SciPy 等价物。如您所见,在 SciPy 中对数组进行切片几乎不需要任何成本,但在应用于 hit 函数的输出时却非常昂贵。有人对这种行为有一些了解吗?

0 投票
0 回答
93 浏览

openmdao - OpenMDAO 可以与 autograd 或 jax 合作吗?

是否可以使用 autograd 或 jax 包为 OpenMDAO 显式组件生成等效的解析导数?即比有限差分更准确(或者可能比复杂步骤方法更准确或更通用?)但没有手动计算和编程分析梯度的工作?

我不是这些软件包的专家,但它们似乎就是为此目的而设计的。

0 投票
1 回答
900 浏览

python - Jaxlib pip 安装失败

从命令行,我尝试按照这个安装教程进行操作,如果可能的话,我想避免从源代码构建。目前,我不确定问题是什么。任何人都可以验证他们在尝试安装 Jaxlib 时得到相同/不同的响应吗?

出于意识,Jax 安装良好,没有任何问题,但是在单独安装的 Jaxlib 中找到了一些支持组件。

0 投票
1 回答
234 浏览

cuda - 在以下位置找不到库:/usr/local/cuda-9.0/targets/aarch64-linux/lib/libcublasLt.so.9.0

我正在尝试在 NVIDIA Jetson TX2 上安装 JAX,但遇到了相当大的问题。

我有 CUDA 9.0,它给了我以下错误:

所以我去寻找,当然那个图书馆不存在。有人对我如何安装该库有任何指示吗?我试过搜索谷歌,它似乎根本不存在。