问题标签 [jax]
For questions regarding programming in ECMAScript (JavaScript/JS) and its various dialects/implementations (excluding ActionScript). Note JavaScript is NOT the same as Java! Please include all relevant tags on your question; e.g., [node.js], [jquery], [json], [reactjs], [angular], [ember.js], [vue.js], [typescript], [svelte], etc.
python - Websockets 消息仅在最后发送,而不是在使用 async / await 的实例中发送,在嵌套的 for 循环中产生
我有一个计算量很大的过程,需要几分钟才能在服务器中完成。所以我想通过 websockets 将每次迭代的结果发送给客户端。
整个应用程序有效,但我的问题是,在整个模拟完成后,所有消息都以一大块形式到达客户端。我必须在这里遗漏一些东西,因为我希望await websocket.send_json()
在此过程中发送消息,而不是最后发送所有消息。
服务器 python (FastAPI)
为了完整起见,这里是基本的客户端代码。
客户
python - 为什么 Mypy 认为添加两个 Jax 数组会返回一个 numpy 数组?
考虑以下文件:
运行mypy mypytest.py
返回以下错误:
由于某种原因,它认为添加两个jax.numpy.ndarray
s 会返回一个 NumPy 数组bools。难道我做错了什么?或者这是 MyPy 或 Jax 的类型注释中的错误?
python - 有没有一种方法可以加快使用 JAX 索引向量的速度?
我正在索引向量并使用JAX,但我注意到与numpy相比,在简单地索引数组时速度相当慢。例如,考虑在 JAX numpy 和普通 numpy 中制作一个基本数组:
然后简单地在两个整数之间建立索引,对于 JAX(在 GPU 上),这给出了一个时间:
1000 次循环,5 次中的最佳:每个循环 1.38 毫秒
对于numpy,这给出了一个时间:
1000000 次循环,5 次中的最佳:每个循环 271 ns
所以 numpy 比 JAX 快 5000 倍。当 JAX 在 CPU 上时,则
1000 个循环,5 个循环中的最佳:每个循环 577 µs
这么快,但仍然比 numpy 慢 2000 倍。我为此使用 Google Colab 笔记本,所以安装/CUDA 应该没有问题。
我错过了什么吗?我意识到 JAX 和 numpy 的索引是不同的,正如JAX 'sharp edges' documentation给出的那样,但我找不到任何方法来执行分配,例如
没有明显放缓。我无法避免索引数组,因为它在我的程序中是必需的。
jax - 如何从俳句中的参数(pytree)中获取参数?(jax 框架)
例如,您设置了一个具有参数的模块。但是,如果您想在损失中规范化某些东西,那么模式是什么?
示例中缺少一些模式。
python - 使用 for 循环时如何减少 JAX 编译时间?
这是一个基本的例子。
当 cons 较小时,编译时间约为一分钟。缺点越大,编译时间就越长——10 分钟。我需要更高的缺点。可以做什么?从我正在阅读的内容来看,循环是原因。它们在编译时展开。有什么解决方法吗?还有 jax.fori_loop。但我不明白如何使用它。有 jax.experimental.loops 模块,但我还是无法理解它。
我对这一切都很陌生。因此,感谢所有帮助。如果您能提供一些如何使用 jax 循环的示例,那将不胜感激。
另外,什么是好的编译时间?几分钟内就可以了吗?在其中一个示例中,编译时间为 262 秒,剩余运行时间约为 0.1-0.2 秒。
运行时的任何收益都会被编译时间所掩盖。
python - Jax 矢量化:vmap 和/或 numpy.vectorize?
jax.numpy.vectorize
和 和有什么区别jax.vmap
?这是一个小片段
两种计算都给出相同的结果:
DeviceArray([ 1. , 0.80998397, 0.63975394, 0.4888039 , 0.35637075, 0.24149445, 0.14307144, 0.05990037, -0.00927836, -0.06574923]), dtype=float32
如何决定使用哪一个,以及在性能方面是否存在差异?
python - JAX/JIT vs Std Numpy 性能:我错在哪里?
这里是一个简单的练习,带有一个辛普森集成代码,我已经制作了它来接受几个函数来集成一组边界
我在 CPU 设备上,然后我得到一个 200 x 100 数字数组,对应于 Int(f_i, a_j,b_j) i:0-199 和 j:0-99
%timeit simps(funcN,a,b, 512)
每个循环 1.13 秒 ± 27.4 毫秒(平均值 ± 标准偏差。7 次运行,每个循环 1 个)
现在考虑以下 JAX/JIT 版本
我已经验证了这两个代码(纯 Numpy 和 JAX/JIT)给出的结果相同,因为最大相对误差约为 8. 10^-16。
现在,我得到以下时间 933 ms ± 51.4 ms 每个循环(平均值±标准偏差。7 次运行,每个循环 1 个)
这与纯 Numpy 非常接近。我是否偶然制作了一个非常有效的纯 Numpy 代码???还是我以错误的方式编码 JAX/JIT?
(nb. 使用 Google collab K80 GPU 时,每个循环的 JAX/JIT 时间下降到 7.19 毫秒,将纯 Numpy 保持在 1 秒/循环的水平)
python - 矢量化物理模拟?
我正在尝试模拟一些二维粒子。每个粒子都是一个有方向的圆。方向由二维单位向量指定。
在我的模拟的一部分中,我想计算粒子对之间的角度和每个粒子的方向的函数。这应该对每个粒子对进行。在视觉上,我想计算 $\theta_i$ 和 $\theta_j$ 的函数(参见图片链接)。
我已经计算了每个粒子对的成对位移单位向量。这是一个名为 r 的形状为 (N, N, 2) 的 numpy 数组,其中 N 是粒子的总数。我还计算了笛卡尔坐标中每个粒子的方向。这是一个称为形状方向(N,2)的numpy数组。
我已经能够编写我需要的代码作为双 for 循环。
但是,运行此代码需要很长时间,尤其是对于大型系统。有没有办法对双 for 循环进行矢量化,使其运行得更快?
python - JIT 无法改进我的 JAX 代码:我哪里错了?
这是一个简单的 JAX 代码,展示了 Metropolis 算法在解决 3 参数贝叶斯回归 pb 的实际操作中。即使在 CPU 上运行 wo JIT 编译也可以。现在我想知道为什么当关于 JIT 的 2 行被取消时,CPU(Jit 或非 JIT)和在 CPU 或 K80/Nvidia GPU 上运行的比较时间并没有真正不同?
我可能以错误/低效的方式编码吗?
然后一旦代码被调用一次就可以了
没有 JIT 的 CPU 时间(即@partial 行评论)我得到 1 分 27 秒,而使用 JIT 我得到 1 分 20 秒(两个结果都是 7 次运行的平均值)感谢您的建议。