我正在索引向量并使用JAX,但我注意到与numpy相比,在简单地索引数组时速度相当慢。例如,考虑在 JAX numpy 和普通 numpy 中制作一个基本数组:
import jax.numpy as jnp
import numpy as onp
jax_array = jnp.ones((1000,))
numpy_array = onp.ones(1000)
然后简单地在两个整数之间建立索引,对于 JAX(在 GPU 上),这给出了一个时间:
%timeit jax_array[435:852]
1000 次循环,5 次中的最佳:每个循环 1.38 毫秒
对于numpy,这给出了一个时间:
%timeit numpy_array[435:852]
1000000 次循环,5 次中的最佳:每个循环 271 ns
所以 numpy 比 JAX 快 5000 倍。当 JAX 在 CPU 上时,则
%timeit jax_array[435:852]
1000 个循环,5 个循环中的最佳:每个循环 577 µs
这么快,但仍然比 numpy 慢 2000 倍。我为此使用 Google Colab 笔记本,所以安装/CUDA 应该没有问题。
我错过了什么吗?我意识到 JAX 和 numpy 的索引是不同的,正如JAX 'sharp edges' documentation给出的那样,但我找不到任何方法来执行分配,例如
new_array = jax_array[435:852]
没有明显放缓。我无法避免索引数组,因为它在我的程序中是必需的。