3

在 autograd/numpy 我可以这样做:

q[q<0] = 0.0

我怎样才能在 JAX 中做同样的事情?

我尝试import numpy as onp并使用它来创建数组,但这似乎不起作用。

4

1 回答 1

4

JAX 数组是不可变的,因此就地索引赋值语句不起作用。相反,jax 提供了jax.ops子模块,它提供了创建数组更新版本的功能。

这是一个 numpy 索引分配和等效的 JAX 索引更新的示例:

import numpy as np
q = np.arange(-5, 5)
q[q < 0] = 0
print(q)
# [0 0 0 0 0 0 1 2 3 4]

import jax.numpy as jnp
q = jnp.arange(-5, 5)
q = q.at[q < 0].set(0)  # NB: this does not modify the original array,
                        # but rather returns a modified copy.
print(q)
# [0 0 0 0 0 0 1 2 3 4]

请注意,在 op-by-op 模式下,JAX 版本确实会创建数组的多个副本。但是,当在 JIT 编译中使用时,XLA 通常可以融合此类操作并避免数据复制。

于 2020-10-19T18:56:55.900 回答