如果我们对数组索引执行非整数计算(然后转换为 int() ),似乎我们仍然无法将结果用作 jit 编译的 jax 代码中的有效索引。我们如何解决这个问题?
以下是一个最小的示例。具体问题:命令 jnp.diag_indices(d) 是否可以在不向 fun() 传递额外参数的情况下工作
在木星单元中运行它:
import jax.numpy as jnp
from jax import jit
@jit
def fun(t):
d = jnp.sqrt(t.size**2)
d = jnp.array(d,int)
jnp.diag_indices(t.size) # this line works
jnp.diag_indices(d) # this line breaks. Comment it out to see that d and t.size have the same dtype=int32
return t.size, d
fun(jnp.array([1,2]))