我正在使用 JAX,我想执行类似的操作
@jax.jit
def fun(x, index):
x[:index] = other_fun(x[:index])
return x
这不能在 下执行jit
。有没有办法用jax.ops
or做到这一点jax.lax
?我曾想过使用jax.ops.index_update(x, idx, y)
,但我找不到一种计算方法y
而不会再次遇到同样的问题。
如果您的索引是静态的, @rvinas using的先前答案dynamic_slice
效果很好,但您也可以使用动态索引来完成此操作jnp.where
。例如:
import jax
import jax.numpy as jnp
def other_fun(x):
return x + 1
@jax.jit
def fun(x, index):
mask = jnp.arange(x.shape[0]) < index
return jnp.where(mask, other_fun(x), x)
x = jnp.arange(5)
print(fun(x, 3))
# [1 2 3 3 4]
您的实施中似乎存在两个问题。首先,切片生成动态形状的数组(在 jitted 代码中不允许)。其次,与 numpy 数组不同,JAX 数组是不可变的(即数组的内容不能更改)。
static_argnums
您可以通过结合和来克服这两个问题jax.lax.dynamic_update_slice
。这是一个例子:
def other_fun(x):
return x + 1
@jax.partial(jax.jit, static_argnums=(1,))
def fun(x, index):
update = other_fun(x[:index])
return jax.lax.dynamic_update_slice(x, update, (0,))
x = jnp.arange(5)
print(fun(x, 3)) # prints [1 2 3 3 4]
本质上,上面的示例static_argnums
用于指示该函数应针对不同的index
值重新编译,并创建一个具有更新值jax.lax.dynamic_update_slice
的副本。x
:len(update)