以下 numpy 代码非常好:
arr = np.arange(50)
print(arr.shape) # (50,)
indices = np.zeros((30,), dtype=int)
print(indices.shape) # (30,)
arr[indices]
迁移到 jax 后它也可以工作:
import jax.numpy as jnp
arr = jnp.arange(50)
print(arr.shape) # (50,)
indices = jnp.zeros((30,), dtype=int)
print(indices.shape) # (30,)
arr[indices]
现在让我们尝试混合使用 numpy 和 jax:
arr = np.arange(50)
print(arr.shape) # (50,)
indices = jnp.zeros((30,), dtype=int)
print(indices.shape) # (30,)
arr[indices]
这会产生以下错误:
IndexError: too many indices for array: array is 1-dimensional, but 30 were indexed
如果不支持使用 jax 数组索引到 numpy 数组,那对我来说很好。但是错误信息似乎是错误的。事情变得更加混乱。如果稍微改变形状,代码就可以正常工作。在下面的示例中,我只编辑了从 (30,) 到 (40,) 的索引形状。没有更多错误消息:
arr = np.arange(50)
print(arr.shape) # (50,)
indices = jnp.zeros((40,), dtype=int)
print(indices.shape) # (40,)
arr[indices]
我在 cpu 上运行 jax 版本“0.2.12”。这里发生了什么?