0

我正在索引向量并使用JAX,但我注意到与numpy相比,在简单地索引数组时速度相当慢。例如,考虑在 JAX numpy 和普通 numpy 中制作一个基本数组:

import jax.numpy as jnp
import numpy as onp 
jax_array = jnp.ones((1000,))
numpy_array = onp.ones(1000)

然后简单地在两个整数之间建立索引,对于 JAX(在 GPU 上),这给出了一个时间:

%timeit jax_array[435:852]

1000 次循环,5 次中的最佳:每个循环 1.38 毫秒

对于numpy,这给出了一个时间:

%timeit numpy_array[435:852]

1000000 次循环,5 次中的最佳:每个循环 271 ns

所以 numpy 比 JAX 快 5000 倍。当 JAX 在 CPU 上时,则

%timeit jax_array[435:852]

1000 个循环,5 个循环中的最佳:每个循环 577 µs

这么快,但仍然比 numpy 慢 2000 倍。我为此使用 Google Colab 笔记本,所以安装/CUDA 应该没有问题。

我错过了什么吗?我意识到 JAX 和 numpy 的索引是不同的,正如JAX 'sharp edges' documentation给出的那样,但我找不到任何方法来执行分配,例如

new_array = jax_array[435:852]

没有明显放缓。我无法避免索引数组,因为它在我的程序中是必需的。

4

1 回答 1

2

简短的回答:要在 JAX 中加快速度,请使用jit.

长答案:

您通常应该期望在 op-by-op 模式下使用 JAX 的单个操作比 numpy 中的类似操作慢。这是因为 JAX 执行有一些固定的 per-python-function-call 开销,涉及将编译推送到 XLA。

甚至像索引这样看似简单的操作也是根据多个 XLA 操作来实现的,这些操作(在 JIT 之外)每个都会增加自己的调用开销。您可以使用转换查看此序列,make_jaxpr以检查函数如何以原始操作表示:

from jax import make_jaxpr
f = lambda x: x[435:852]
make_jaxpr(f)(jax_array)
# { lambda  ; a.
#   let b = broadcast_in_dim[ broadcast_dimensions=(  )
#                             shape=(1,) ] 435
#       c = gather[ dimension_numbers=GatherDimensionNumbers(offset_dims=(0,), collapsed_slice_dims=(), start_index_map=(0,))
#                   indices_are_sorted=True
#                   slice_sizes=(417,)
#                   unique_indices=True ] a b
#       d = broadcast_in_dim[ broadcast_dimensions=(0,)
#                             shape=(417,) ] c
#   in (d,) }

(有关如何阅读本文的信息,请参阅了解 Jaxprs )。

JAX 优于 numpy 的地方不在于单个小操作(其中 JAX 调度开销占主导地位),而在于通过jit转换编译的操作序列。因此,例如,比较 JIT 编译和非 JIT 编译的索引版本:

%timeit f(jax_array).block_until_ready()
# 1000 loops, best of 5: 612 µs per loop

f_jit = jit(f)
f_jit(jax_array)  # trigger compilation
%timeit f_jit(jax_array).block_until_ready()
# 100000 loops, best of 5: 4.34 µs per loop

(请注意,由于 JAX 的异步调度block_until_ready(),需要准确的微基准测试)

JIT 编译此代码可提供 150 倍的加速。由于 JAX 的几毫秒调度开销,它仍然不如 numpy 快,但是使用 JIT,开销只会产生一次。而且,当您从微基准转移到更复杂的实际计算序列时,那几毫秒将不再占主导地位,XLA 编译器提供的优化可以使 JAX 比等效的 numpy 计算快得多。

于 2021-08-27T15:51:22.593 回答