问题标签 [jax]
For questions regarding programming in ECMAScript (JavaScript/JS) and its various dialects/implementations (excluding ActionScript). Note JavaScript is NOT the same as Java! Please include all relevant tags on your question; e.g., [node.js], [jquery], [json], [reactjs], [angular], [ember.js], [vue.js], [typescript], [svelte], etc.
java - 使用基本身份验证在 Java 中获取和发布 API 调用
我想在GET
不使用任何. 我需要使用基本身份验证。任何人都可以帮我提供一些教程链接。在谷歌中,我只在框架中找到了代码,但我没有使用. 我正在寻找调用 API 的代码POST
java
framework
spring
Spring
basic authentication.
我必须在下面添加新url
的。如果是安全的,需要什么修改是方法。我是新手,所以不太了解。authentication
code
API
basic auth
POST
java
python - 无法安装 jaxlib
我正在尝试通过在文档中找到的以下命令在我的 Windows 10 上安装 jaxlib。
点安装 jaxlib
它显示以下错误
任何人都可以帮助我,提前谢谢!
python - 计算 Jacobian x Jacobian.T 的有效方法
假设J
是某个函数f
关于某些参数的雅可比行列式。是否有有效的方法(在 PyTorch 或 Jax 中)拥有一个接受两个输入(x1
和x2
)并在J(x1)*J(x2).transpose()
不实例化内存中的整个J
矩阵的情况下进行计算的函数?
我遇到过类似jvp(f, input, v=vjp(f, input))
但不太理解的东西,也不确定是我想要的。
compilation - 使用 jax 时 XLA 的 jit 编译速度非常慢
我正在使用 Jax 做一些机器学习工作。Jax 使用 XLA 进行一些即时编译以加速,但编译本身在 CPU 上太慢了。我的情况是 CPU 只会使用一个内核来进行编译,这根本没有效率。
我找到了一些答案,如果我可以使用 GPU 进行编译,它会非常快。谁能告诉我如何使用 GPU 来完成编译部分?由于我没有对编译进行任何配置。谢谢!
问题的一些补充:我正在使用 Jax 计算 grad 和 hessian,这会使编译非常慢。代码如下:
python - 从函数中有效地填充数组
我想以我可以利用的方式从函数构造一个二维数组jax.jit
。
我通常使用的方法numpy
是创建一个空数组,然后就地填充该数组。
为了使这项工作在jax
我尝试使用jax.opt.index_update
.
这运行没有错误,但是当我尝试使用@jax.jit
装饰器时非常慢(至少比纯 python/numpy 版本慢一个数量级)。
从函数中填充多维数组的最佳方法是什么jax
?
python - JAX 中的条件更新?
在 autograd/numpy 我可以这样做:
我怎样才能在 JAX 中做同样的事情?
我尝试import numpy as onp
并使用它来创建数组,但这似乎不起作用。
python - scipy stats zmap 函数的替代方案
zmap 函数的 scipy stats 模块有什么替代方法吗?我目前正在使用它来获取两个非常大的数组的 zmap 分数,这需要相当长的时间。
是否有任何库或替代品可以提高其性能?或者甚至是另一个获得 zmap 函数的作用?
您的想法和意见将不胜感激!
这是我下面的最小可重现代码:
这就是 scipy stats.zmap 在幕后所做的:
关于如何针对我的用例优化它的任何想法?我可以使用像 numba 或 JAX 这样的库来进一步提升它吗?
python - 为什么这个函数在 JAX 和 numpy 中比较慢?
我有以下 numpy 函数,如下所示,我正在尝试使用 JAX 进行优化,但无论出于何种原因,它都比较慢。
有人可以指出我可以做些什么来提高这里的性能吗?我怀疑这与 Cg_new 发生的列表理解有关,但将其分开并不会在 JAX 中产生任何进一步的性能提升。
这是 JAX 等价物:
python - 使用 pytorch 在输入梯度上训练神经网络
我目前正在尝试使用 pytorch 训练神经网络,我尝试在输入导数上匹配输入。我想这样做是因为这确保了一个保守的向量场。(在为分子动力学中的力匹配训练神经网络时完成)这意味着:
问题是,如果我尝试更新神经网络的参数,所有参数的梯度都是 0。我确保模型正常工作;我不知道如何构建模型正在正确训练的图形。在 Jaxmd 中,可以像 [Jax Glass Training][1] 所示训练这样的模型。我也试过
但这会产生类似的结果并且没有意义。[1]:https ://colab.research.google.com/github/google/jax-md/blob/master/notebooks/neural_networks.ipynb#scrollTo=WNs8v2745Mc3
编辑:
更新了复制 pytorch 版本 1.6.0 的代码示例
python - 如何保存 JAX 训练模型的优化器状态?
我正在玩 mnist_vae 示例,但无法弄清楚如何正确保存/加载训练模型的权重。
之后,我使用 opt_update 训练模型并希望保存它。但是,我还没有找到将优化器状态保存到磁盘的任何功能。
我尝试保存参数并用它们初始化 opt_state,但并非所有信息都保存下来,结果 opt_state_1 不是原来的 opt_state。
如何正确保存我训练的模型?