问题标签 [jax]
For questions regarding programming in ECMAScript (JavaScript/JS) and its various dialects/implementations (excluding ActionScript). Note JavaScript is NOT the same as Java! Please include all relevant tags on your question; e.g., [node.js], [jquery], [json], [reactjs], [angular], [ember.js], [vue.js], [typescript], [svelte], etc.
python - 使用vmap(jax)对矩阵元素求和?
我试图了解 vmap 中的 in_axes 和 out_axes 选项。例如,我想对两个矩阵求和并得到相同形状的输出。
我想我分别为 X 和 Y 映射了轴 0 和 1。输出将具有与 X,Y 相同的形状。但我得到了错误,
python - jax vmap 函数中的调试数组
亲爱的 jax 专家,我需要您的帮助。
这是一个工作示例(我已经按照建议简化了我的代码,尽管我不是 jax 方面的专家,也不是 Python 方面的专家来猜测 vmap 所涉及的机制的核心是什么)
然后我的问题涉及维度演变print(f"mh_update: positions[{i-1}]:",jnp.asarray(positions[i-1]))
。我不明白为什么positions[i-1]
从维度开始n_dim
然后切换到n_chains x n_dim
?
提前感谢您的评论?
这是完整的输出:
python - JAX 仅在 jit 下的数组切片上应用函数
我正在使用 JAX,我想执行类似的操作
这不能在 下执行jit
。有没有办法用jax.ops
or做到这一点jax.lax
?我曾想过使用jax.ops.index_update(x, idx, y)
,但我找不到一种计算方法y
而不会再次遇到同样的问题。
python - 使用 CUDA 安装 JAX 时出现错误'/usr/bin/bash: line 1: realpath: command not found'
我正在尝试在我的 Windows 笔记本电脑上使用 CUDA 从源代码构建 JAX。我已经安装了 MSYS2。
我正在按照此处给出的说明进行操作
但是,我无法按照文档中的说明realpath
进行安装。每当我从 MSYS2运行时,我都会pacman -S patch realpath
收到错误消息。error: target not found: realpath
pacman -S realpath
我认为这可能是我收到错误的原因。但是,我无法弄清楚如何使用安装realpath
,pacman
因为我在任何地方都找不到它。
对此的任何帮助表示赞赏。
我在下面给出了整个错误消息。
python - 在 JAX 中使用 VJP 时有没有办法禁用前向评估?
我在我的项目中经常使用 VJP。它运行受雅可比计算约束的函数,并返回一个 primals_out 以及可调用的 vjp 函数。例如,JAX 文档中的自定义 VJP 定义如下所示:
在此示例中,我们看到使用 VJP 时需要评估前向函数。使用常规 VJP 而不是自定义 VJP 时也是如此。但是,当函数的评估成本很高并且因为我已经在我的代码中的某个地方运行了该函数时,我不希望 VJP 再评估一次该函数。
那么,有没有办法表明在计算其 VJP 时不会评估函数?
python - 在 scipy.optimize.minimize 中的 NonlinearConstraint 中指定粗麻布的问题
我无法在scipy.optimize.minimize
. 我创建了一个最小的问题来仔细检查,但我也无法让它工作。有人会碰巧知道问题是什么吗?
这是我的例子:
jax - 关于每次调用 jax.jit 的函数重新编译
我是新手jax
。当我阅读文档时,我对jit
.
在缓存部分,它说“避免在循环内调用 jax.jit。这样做有效地在每次调用时创建一个新的 f ,每次都会编译它而不是重用相同的缓存函数”。但是,运行以下代码只会产生一种打印副作用:
字符串“tracing...”只打印一次,似乎jit
不再跟踪函数。
这是故意的吗?谢谢你的帮助!
python - JAX - 区分功能的问题
我正在尝试在调用中执行蒙特卡罗模拟,然后在 Python 中计算其相对于基础资产的一阶导数,但它仍然不起作用
我不知道实现该算法的方式是否是使用“AD 方法”计算导数的最佳方式;该算法以这种方式工作:
S = 模拟一个包含所有底层证券的矩阵;对于每一行,我使用“xi = jnp.linspace”生成每个底层,并且在矩阵的每一行内,我有相同的值多次等于“number_sim”
product = 生成 BM(包含正常数的向量)后,我需要将BM的每个元素(带有 exp)与S的每一行的每个元素相乘
所以这是对算法的简短解释,我非常感谢任何管理这个问题的建议或技巧,并用 AD 方法计算导数!提前致谢
tensorflow - NCCL 操作 ncclGroupEnd() 失败:未处理的系统错误
我可以在 colab 上运行此文件vit_jax.ipynb并执行训练并运行我的实验,但是当我尝试在我的集群上复制它时,我在下面给出的训练期间遇到错误。但是,计算准确性的前向传递在我的集群上运行良好。
我的集群上有 4 个带有 CUDA10.1 版本的 GTX 1080,并使用 tensorflow==2.4.0 和 jax[cuda101]==0.2.18。我在 docker 容器内将它作为 jupyter notebook 运行。
请让我知道是否有人以前遇到过这个问题?或者有什么办法解决这个问题?
jax - 在 jax 中实现 Dense 层的“正确”或最佳方法是什么,其中每个层可能有也可能没有偏差?
例如,jax.experimental.stax
有一个像这样实现的密集层:
例如,如果我们将bias 实现为允许为None,或者params 的长度为1,那么就会对grad 的工作方式产生影响。
在这里应该瞄准的模式是什么?jax.jit
有一个static_argnums
我想可以与一些has_bias
参数一起使用的,但涉及到簿记,我相信某处必须有一些例子。