问题标签 [jax]
For questions regarding programming in ECMAScript (JavaScript/JS) and its various dialects/implementations (excluding ActionScript). Note JavaScript is NOT the same as Java! Please include all relevant tags on your question; e.g., [node.js], [jquery], [json], [reactjs], [angular], [ember.js], [vue.js], [typescript], [svelte], etc.
python - 简单的误解,神经切线入门
我复制并简化了这段极其简单的代码,只是为了开始使用神经切线和 jax
据我了解,这训练了一个“无限宽”的神经网络,以将一个训练示例实数拟合到所需的目标(也是单个实数)。
所以y
和predictions
应该是一样的。我正在训练一个例子,我有一个无限强大的模型,在我的脑海中这两个应该是相同的。他们不是。它打印:
更重要的是,如果我将训练示例的数量更改为 3,那么现在两个打印的尺寸不匹配!我希望找到两个向量,每个向量包含三个数字。我得到的是:
很明显,我有一个致命的误解。文档没有帮助我。任何人都可以阐明这个问题吗?
python - 选择 JAX 矩阵子集的最快方法是什么?
假设我有一个二维矩阵,我想在直方图中绘制它的值。为此,我需要执行以下操作:
然后使用列表绘制直方图。到目前为止一切顺利,只是原始矩阵中有我想排除的项目。为简单起见,假设我有一个这样的列表:
因此,list_1d
应该有矩阵中的所有项目,而没有指向的exclude
项目(的项目exclude
是行和列索引)。
顺便说一句,这matrix_2d
是一个 JAX 数组,这意味着它的内容在 GPU 中。
python - 在 gunicorn/flask 服务器中使用 google 的 JAX
我想提供一个应用程序,该应用程序使用烧瓶和 gunicorn 在 googles JAX 框架中处理数据。
如果在烧瓶内运行,一切正常。一旦我在 gunicorn 中运行应用程序,每个与 jax 相关的部分都会导致工作进程死亡,而不会引发任何异常。我尝试同时使用同步和 gthreads 作为工作线程,但结果相同。
我试图通过在 ThreadPoolExecutor 和 ProcessPoolExecutor 中包装相同的调用来查看 JAX 是否可以处理多处理和多线程,并且可以完美地工作。
在调试期间,每次我检查 JAX DeviceArray 时,应用程序都会崩溃。使用 JAX 跳过第一个计算也是如此。
任何帮助将非常感激!
python - 使用 numpy 和 jax 进行非传递子类化
我的问题很简单:
?
现在我会闲逛,所以 SE 会接受我的合理问题。
python - Jax 中的 vmap ops.index_update
我在下面有以下代码,它使用了一个简单的 for 循环。我只是想知道是否有办法 vmap 它?这是原始代码:
这是我使用 vmap 的尝试:
但我收到以下错误:
TypeError: vmap in_axes 必须是一个 int、None 或(嵌套)容器,这些类型作为叶子,但得到了 Traced<ShapedArray(float32[192,334])>with<DynamicJaxprTrace(level=0/1)>。
我有点困惑,因为范围是 int 类型,所以我不太确定发生了什么。
最后,我试图让这个小部分尽可能地优化,以获得最短的时间。
performance - 对矩阵元素求幂的两种方法的比较
我有两种对jnp = jax.numpy
. 一个直截了当的:
还有一些额外的动作:
但是,当我测试它们时:
尽管从表面上看有一些额外的开销,但第二种方法表现得更好。我运行了%timeit
一个大小为 2000 x 2000 的矩阵:
为什么会这样?
openmdao - OpenMDAO 可以与 autograd 或 jax 合作吗?
是否可以使用 autograd 或 jax 包为 OpenMDAO 显式组件生成等效的解析导数?即比有限差分更准确(或者可能比复杂步骤方法更准确或更通用?)但没有手动计算和编程分析梯度的工作?
我不是这些软件包的专家,但它们似乎就是为此目的而设计的。
python - Jaxlib pip 安装失败
从命令行,我尝试按照这个安装教程进行操作,如果可能的话,我想避免从源代码构建。目前,我不确定问题是什么。任何人都可以验证他们在尝试安装 Jaxlib 时得到相同/不同的响应吗?
出于意识,Jax 安装良好,没有任何问题,但是在单独安装的 Jaxlib 中找到了一些支持组件。
cuda - 在以下位置找不到库:/usr/local/cuda-9.0/targets/aarch64-linux/lib/libcublasLt.so.9.0
我正在尝试在 NVIDIA Jetson TX2 上安装 JAX,但遇到了相当大的问题。
我有 CUDA 9.0,它给了我以下错误:
所以我去寻找,当然那个图书馆不存在。有人对我如何安装该库有任何指示吗?我试过搜索谷歌,它似乎根本不存在。