问题标签 [jax]

For questions regarding programming in ECMAScript (JavaScript/JS) and its various dialects/implementations (excluding ActionScript). Note JavaScript is NOT the same as Java! Please include all relevant tags on your question; e.g., [node.js], [jquery], [json], [reactjs], [angular], [ember.js], [vue.js], [typescript], [svelte], etc.

0 投票
1 回答
4814 浏览

java - 使用基本身份验证在 Java 中获取和发布 API 调用

我想在GET不使用任何. 我需要使用基本身份验证。任何人都可以帮我提供一些教程链接。在谷歌中,我只在框架中找到了代码,但我没有使用. 我正在寻找调用 API 的代码POSTjavaframeworkspringSpringbasic authentication.

我必须在下面添加新url的。如果是安全的,需要什么修改是方法。我是新手,所以不太了解。authenticationcodeAPIbasic authPOSTjava

0 投票
4 回答
10076 浏览

python - 无法安装 jaxlib

我正在尝试通过在文档中找到的以下命令在我的 Windows 10 上安装 jaxlib。

点安装 jaxlib

它显示以下错误

任何人都可以帮助我,提前谢谢!

0 投票
1 回答
673 浏览

python - 计算 Jacobian x Jacobian.T 的有效方法

假设J是某个函数f关于某些参数的雅可比行列式。是否有有效的方法(在 PyTorch 或 Jax 中)拥有一个接受两个输入(x1x2)并在J(x1)*J(x2).transpose() 实例化内存中的整个J矩阵的情况下进行计算的函数?

我遇到过类似jvp(f, input, v=vjp(f, input))但不太理解的东西,也不确定是我想要的。

0 投票
0 回答
717 浏览

compilation - 使用 jax 时 XLA 的 jit 编译速度非常慢

我正在使用 Jax 做一些机器学习工作。Jax 使用 XLA 进行一些即时编译以加速,但编译本身在 CPU 上太慢了。我的情况是 CPU 只会使用一个内核来进行编译,这根本没有效率。

我找到了一些答案,如果我可以使用 GPU 进行编译,它会非常快。谁能告诉我如何使用 GPU 来完成编译部分?由于我没有对编译进行任何配置。谢谢!

问题的一些补充:我正在使用 Jax 计算 grad 和 hessian,这会使编译非常慢。代码如下:

0 投票
2 回答
458 浏览

python - 从函数中有效地填充数组

我想以我可以利用的方式从函数构造一个二维数组jax.jit

我通常使用的方法numpy是创建一个空数组,然后就地填充该数组。

为了使这项工作在jax我尝试使用jax.opt.index_update.

这运行没有错误,但是当我尝试使用@jax.jit装饰器时非常慢(至少比纯 python/numpy 版本慢一个数量级)。

从函数中填充多维数组的最佳方法是什么jax

0 投票
1 回答
905 浏览

python - JAX 中的条件更新?

在 autograd/numpy 我可以这样做:

我怎样才能在 JAX 中做同样的事情?

我尝试import numpy as onp并使用它来创建数组,但这似乎不起作用。

0 投票
1 回答
94 浏览

python - scipy stats zmap 函数的替代方案

zmap 函数的 scipy stats 模块有什么替代方法吗?我目前正在使用它来获取两个非常大的数组的 zmap 分数,这需要相当长的时间。

是否有任何库或替代品可以提高其性能?或者甚至是另一个获得 zmap 函数的作用?

您的想法和意见将不胜感激!

这是我下面的最小可重现代码:

这就是 scipy stats.zmap 在幕后所做的:

关于如何针对我的用例优化它的任何想法?我可以使用像 numba 或 JAX 这样的库来进一步提升它吗?

0 投票
2 回答
2283 浏览

python - 为什么这个函数在 JAX 和 numpy 中比较慢?

我有以下 numpy 函数,如下所示,我正在尝试使用 JAX 进行优化,但无论出于何种原因,它都比较慢。

有人可以指出我可以做些什么来提高这里的性能吗?我怀疑这与 Cg_new 发生的列表理解有关,但将其分开并不会在 JAX 中产生任何进一步的性能提升。

这是 JAX 等价物:

0 投票
0 回答
366 浏览

python - 使用 pytorch 在输入梯度上训练神经网络

我目前正在尝试使用 pytorch 训练神经网络,我尝试在输入导数上匹配输入。我想这样做是因为这确保了一个保守的向量场。(在为分子动力学中的力匹配训练神经网络时完成)这意味着:

问题是,如果我尝试更新神经网络的参数,所有参数的梯度都是 0。我确保模型正常工作;我不知道如何构建模型正在正确训练的图形。在 Jaxmd 中,可以像 [Jax Glass Training][1] 所示训练这样的模型。我也试过

但这会产生类似的结果并且没有意义。[1]:https ://colab.research.google.com/github/google/jax-md/blob/master/notebooks/neural_networks.ipynb#scrollTo=WNs8v2745Mc3

编辑:

更新了复制 pytorch 版本 1.6.0 的代码示例

0 投票
1 回答
788 浏览

python - 如何保存 JAX 训练模型的优化器状态?

我正在玩 mnist_vae 示例,但无法弄清楚如何正确保存/加载训练模型的权重。

之后,我使用 opt_update 训练模型并希望保存它。但是,我还没有找到将优化器状态保存到磁盘的任何功能。

我尝试保存参数并用它们初始化 opt_state,但并非所有信息都保存下来,结果 opt_state_1 不是原来的 opt_state。

如何正确保存我训练的模型?