问题标签 [jax]

For questions regarding programming in ECMAScript (JavaScript/JS) and its various dialects/implementations (excluding ActionScript). Note JavaScript is NOT the same as Java! Please include all relevant tags on your question; e.g., [node.js], [jquery], [json], [reactjs], [angular], [ember.js], [vue.js], [typescript], [svelte], etc.

0 投票
1 回答
1003 浏览

python - 自动区分可能不是紧密形式的多元函数

我正在尝试计算可能是也可能不是紧密形式的多元函数的一堆一阶导数。为了为您提供更多上下文,我正在尝试计算选项的“希腊语”。期权价格/价值取决于很多因素:现货价格、行使价、波动率和利率等。最常用的希腊语之一称为delta,它是期权价格/价值相对于股票现货价格变化的一个单位的变化。期权的价格可能没有接近形式/分析形式,尽管为了简单起见,我在这里使用了一些接近形式。实际上,可以使用蒙特卡罗模拟计算价格。关键是,我需要一种“NumPy 友好”的方式来计算某些函数的这些一阶导数。这就是我相信很多机器学习/深度学习的人可以帮助我的地方。我参加了一些机器学习的入门课程,并且知道有一个自动微分、反向传播和其他东西的世界。我在这里使用的库是 JAX,它似乎与“numpy”有一些问题,因为错误消息如下所示:

请注意,我正在使用“定价器”,这是一个由其他人编写的定价函数,这个定价函数是用 numpy 编写的,无法使用其他库编写。工作量太大了。我必须“应用”他用 numpy 编写的定价函数。

顺便说一句,我修改了从某个论坛看到的代码。在原始代码中,使用的函数是一个五变量函数。我所做的只是简单地添加一个名为“divyield”的变量,它就是行不通!非常感谢!我感谢任何帮助或指示!

0 投票
1 回答
96 浏览

python - jax中的高阶多元导数

我对如何在 jax 中计算高阶多元导数感到困惑。

例如,你如何计算 d^2f / dx dy

其中 x, y 在 R^n, n >= 1 中?

我一直在尝试jax.jvpand jax.partial,但我没有任何成功。

0 投票
1 回答
144 浏览

python-3.x - 如何解决 JAX/Python 中的 ValueError `vector::reserve`?

编辑:这里的 GitHub 问题:https ://github.com/google/jax/issues/5190

我正在尝试使用 jit 优化以下功能:

上面的例程在这里使用:

但我收到以下错误:

这里问题的根源是什么?不使用static_argnums错误消息是

具有相同的回溯。

0 投票
0 回答
195 浏览

machine-learning - 在向后传递 Google-JAX 中保存梯度

我正在使用JAX来实现一个简单的神经网络 (NN),并且我想在 NN 运行后访问并保存反向传播的梯度以供进一步分析。我可以使用 python 调试器临时访问和查看渐变(只要我不使用 jit)。但我想保存整个训练过程中的所有梯度,并在训练完成后对其进行分析。我为此使用 id_tap 和全局变量想出了一个相当老套的解决方案(参见下面的代码)。但我想知道是否有更好的解决方案不违反 JAX 的功能原则。

非常感谢!

0 投票
2 回答
1824 浏览

installation - 无法安装特定的 JAX jaxlib GPU 版本

我正在尝试安装特定版本的jaxlib以与我的 CUDA 和 cuDNN 版本一起使用。按照自述文件,我正在尝试

pip install --upgrade jax jaxlib==0.1.52+cuda101 -f https://storage.googleapis.com/jax-releases/jax_releases.html

这将返回以下错误:

ERROR: Requested jaxlib==0.1.52+cuda101 from https://storage.googleapis.com/jax-releases/cuda101/jaxlib-0.1.52%2Bcuda101-cp37-none-manylinux2010_x86_64.whl has different version in metadata: '0.1.52'

有谁知道是什么原因导致这个或如何解决这个错误?

0 投票
1 回答
328 浏览

jax - 如何使用 JIT 处理 JAX 重塑

我正在尝试按照此处的描述实现 entmax-alpha 。

这是代码。

当我使用以下代码调用它时:

我收到以下错误。

它似乎总是与重塑操作有关。我不确定为什么会发生这种情况,任何帮助将不胜感激。

要重现问题,这里是colab 笔记本

非常感谢。

0 投票
1 回答
494 浏览

python - Jax 找不到静态 argnums

这与这个问题有关。我设法充分利用了代码,除了一件奇怪的事情。

这是修改后的代码。

此代码将产生如下错误:

这是由这行代码引起的

即使我只用实体函数替换函数体,错误仍然存​​在。这是一个非常奇怪的行为。但是,让这个东西保持静态对我来说非常重要,因为它有助于展开循环。

0 投票
1 回答
142 浏览

python - 用于多个输入变量的 JAX 自定义 VJP 函数不适用于 NumPyro/HMC-NUTS

我正在尝试使用自定义 VJP(矢量雅可比积)函数作为 numpyro 中 HMC-NUTS 的模型。我能够制作一个适用于 HMC-NUTS 的单变量函数,如下所示:

在这里,我手动定义了 h(x)=sin(x)。然后,我做了一个测试数据

测试数据

在这种情况下,我能够在 NumPyro 中执行 HMC-NUTS

有用。

但是,如果我将多变量函数定义为,

然后执行 HMC-NUTS 作为

然后我得到一个错误

我怀疑我的函数中的输出形状是错误的。但是,经过各种尝试改变形状后,我无法弄清楚出了什么问题。

0 投票
1 回答
1261 浏览

python - 使用 vmap 时,Jax 不支持不可散列的静态参数

这与这个问题有关。经过一番工作,我设法将其更改为最后一个错误。代码现在看起来像这样。

这导致(我希望)是最终错误,即

这很奇怪,因为我认为我已经标记了每一个地方axis似乎都是静态的,但它仍然告诉我它是被追踪的。

0 投票
1 回答
275 浏览

python - 从 JAX 中的多元正态分布采样会产生类型错误

我正在尝试使用 JAX 从多元正态分布生成样本:

但是,当我运行代码时,出现以下错误:

我不确定问题是什么,因为相同的语法适用于 Numpy 中的等效函数