问题标签 [jax]

For questions regarding programming in ECMAScript (JavaScript/JS) and its various dialects/implementations (excluding ActionScript). Note JavaScript is NOT the same as Java! Please include all relevant tags on your question; e.g., [node.js], [jquery], [json], [reactjs], [angular], [ember.js], [vue.js], [typescript], [svelte], etc.

0 投票
1 回答
164 浏览

python - Websockets 消息仅在最后发送,而不是在使用 async / await 的实例中发送,在嵌套的 for 循环中产生

我有一个计算量很大的过程,需要几分钟才能在服务器中完成。所以我想通过 websockets 将每次迭代的结果发送给客户端。

整个应用程序有效,但我的问题是,在整个模拟完成后,所有消息都以一大块形式到达客户端。我必须在这里遗漏一些东西,因为我希望await websocket.send_json()在此过程中发送消息,而不是最后发送所有消息。

服务器 python (FastAPI)

为了完整起见,这里是基本的客户端代码。

客户

0 投票
2 回答
227 浏览

python - 为什么 Mypy 认为添加两个 Jax 数组会返回一个 numpy 数组?

考虑以下文件:

运行mypy mypytest.py返回以下错误:

由于某种原因,它认为添加两个jax.numpy.ndarrays 会返回一个 NumPy 数组bools。难道我做错了什么?或者这是 MyPy 或 Jax 的类型注释中的错误?

0 投票
1 回答
370 浏览

python - 带有差异的 Python 中的错误消息

我正在使用 Montecarlo 方法为通用看涨期权计算这些衍生品。我对这种组合导数感兴趣(关于 S 和 Sigma)。使用算法微分执行此操作,我得到一个可以在页面末尾看到的错误。什么是可能的解决方案?只是为了解释有关代码的内容,我将在下面的代码中附上用于计算“X”的公式:

在此处输入图像描述

这是错误消息:

下面的堆栈跟踪不包括 JAX 内部帧。前面是发生的原始异常,未经修改。


上述异常是以下异常的直接原因:

0 投票
1 回答
179 浏览

python - 有没有一种方法可以加快使用 JAX 索引向量的速度?

我正在索引向量并使用JAX,但我注意到与numpy相比,在简单地索引数组时速度相当慢。例如,考虑在 JAX numpy 和普通 numpy 中制作一个基本数组:

然后简单地在两个整数之间建立索引,对于 JAX(在 GPU 上),这给出了一个时间:

1000 次循环,5 次中的最佳:每个循环 1.38 毫秒

对于numpy,这给出了一个时间:

1000000 次循环,5 次中的最佳:每个循环 271 ns

所以 numpy 比 JAX 快 5000 倍。当 JAX 在 CPU 上时,则

1000 个循环,5 个循环中的最佳:每个循环 577 µs

这么快,但仍然比 numpy 慢 2000 倍。我为此使用 Google Colab 笔记本,所以安装/CUDA 应该没有问题。

我错过了什么吗?我意识到 JAX 和 numpy 的索引是不同的,正如JAX 'sharp edges' documentation给出的那样,但我找不到任何方法来执行分配,例如

没有明显放缓。我无法避免索引数组,因为它在我的程序中是必需的。

0 投票
1 回答
58 浏览

jax - 如何从俳句中的参数(pytree)中获取参数?(jax 框架)

例如,您设置了一个具有参数的模块。但是,如果您想在损失中规范化某些东西,那么模式是什么?

示例中缺少一些模式。

0 投票
2 回答
641 浏览

python - 使用 for 循环时如何减少 JAX 编译时间?

这是一个基本的例子。

当 cons 较小时,编译时间约为一分钟。缺点越大,编译时间就越长——10 分钟。我需要更高的缺点。可以做什么?从我正在阅读的内容来看,循环是原因。它们在编译时展开。有什么解决方法吗?还有 jax.fori_loop。但我不明白如何使用它。有 jax.experimental.loops 模块,但我还是无法理解它。

我对这一切都很陌生。因此,感谢所有帮助。如果您能提供一些如何使用 jax 循环的示例,那将不胜感激。

另外,什么是好的编译时间?几分钟内就可以了吗?在其中一个示例中,编译时间为 262 秒,剩余运行时间约为 0.1-0.2 秒。

运行时的任何收益都会被编译时间所掩盖。

0 投票
1 回答
437 浏览

python - Jax 矢量化:vmap 和/或 numpy.vectorize?

jax.numpy.vectorize和 和有什么区别jax.vmap?这是一个小片段

两种计算都给出相同的结果:

DeviceArray([ 1. , 0.80998397, 0.63975394, 0.4888039 , 0.35637075, 0.24149445, 0.14307144, 0.05990037, -0.00927836, -0.06574923]), dtype=float32

如何决定使用哪一个,以及在性能方面是否存在差异?

0 投票
1 回答
238 浏览

python - JAX/JIT vs Std Numpy 性能:我错在哪里?

这里是一个简单的练习,带有一个辛普森集成代码,我已经制作了它来接受几个函数来集成一组边界

我在 CPU 设备上,然后我得到一个 200 x 100 数字数组,对应于 Int(f_i, a_j,b_j) i:0-199 和 j:0-99

%timeit simps(funcN,a,b, 512)

每个循环 1.13 秒 ± 27.4 毫秒(平均值 ± 标准偏差。7 次运行,每个循环 1 个)

现在考虑以下 JAX/JIT 版本

我已经验证了这两个代码(纯 Numpy 和 JAX/JIT)给出的结果相同,因为最大相对误差约为 8. 10^-16。

现在,我得到以下时间 933 ms ± 51.4 ms 每个循环(平均值±标准偏差。7 次运行,每个循环 1 个)

这与纯 Numpy 非常接近。我是否偶然制作了一个非常有效的纯 Numpy 代码???还是我以错误的方式编码 JAX/JIT?

(nb. 使用 Google collab K80 GPU 时,每个循环的 JAX/JIT 时间下降到 7.19 毫秒,将纯 Numpy 保持在 1 秒/循环的水平)

0 投票
1 回答
112 浏览

python - 矢量化物理模拟?

我正在尝试模拟一些二维粒子。每个粒子都是一个有方向的圆。方向由二维单位向量指定。

在我的模拟的一部分中,我想计算粒子对之间的角度和每个粒子的方向的函数。这应该对每个粒子对进行。在视觉上,我想计算 $\theta_i$ 和 $\theta_j$ 的函数(参见图片链接)。

我已经计算了每个粒子对的成对位移单位向量。这是一个名为 r 的形状为 (N, N, 2) 的 numpy 数组,其中 N 是粒子的总数。我还计算了笛卡尔坐标中每个粒子的方向。这是一个称为形状方向(N,2)的numpy数组。

我已经能够编写我需要的代码作为双 for 循环。

但是,运行此代码需要很长时间,尤其是对于大型系统。有没有办法对双 for 循环进行矢量化,使其运行得更快?

0 投票
1 回答
63 浏览

python - JIT 无法改进我的 JAX 代码:我哪里错了?

这是一个简单的 JAX 代码,展示了 Metropolis 算法在解决 3 参数贝叶斯回归 pb 的实际操作中。即使在 CPU 上运行 wo JIT 编译也可以。现在我想知道为什么当关于 JIT 的 2 行被取消时,CPU(Jit 或非 JIT)和在 CPU 或 K80/Nvidia GPU 上运行的比较时间并没有真正不同?

我可能以错误/低效的方式编码吗?

然后一旦代码被调用一次就可以了

没有 JIT 的 CPU 时间(即@partial 行评论)我得到 1 分 27 秒,而使用 JIT 我得到 1 分 20 秒(两个结果都是 7 次运行的平均值)感谢您的建议。