我正在使用 JAX,我想执行类似的操作
@jax.jit
def fun(x, index):
x[:index] = other_fun(x[:index])
return x
这不能在 下执行jit。有没有办法用jax.opsor做到这一点jax.lax?我曾想过使用jax.ops.index_update(x, idx, y),但我找不到一种计算方法y而不会再次遇到同样的问题。
如果您的索引是静态的, @rvinas using的先前答案dynamic_slice效果很好,但您也可以使用动态索引来完成此操作jnp.where。例如:
import jax
import jax.numpy as jnp
def other_fun(x):
return x + 1
@jax.jit
def fun(x, index):
mask = jnp.arange(x.shape[0]) < index
return jnp.where(mask, other_fun(x), x)
x = jnp.arange(5)
print(fun(x, 3))
# [1 2 3 3 4]
您的实施中似乎存在两个问题。首先,切片生成动态形状的数组(在 jitted 代码中不允许)。其次,与 numpy 数组不同,JAX 数组是不可变的(即数组的内容不能更改)。
static_argnums您可以通过结合和来克服这两个问题jax.lax.dynamic_update_slice。这是一个例子:
def other_fun(x):
return x + 1
@jax.partial(jax.jit, static_argnums=(1,))
def fun(x, index):
update = other_fun(x[:index])
return jax.lax.dynamic_update_slice(x, update, (0,))
x = jnp.arange(5)
print(fun(x, 3)) # prints [1 2 3 3 4]
本质上,上面的示例static_argnums用于指示该函数应针对不同的index值重新编译,并创建一个具有更新值jax.lax.dynamic_update_slice的副本。x:len(update)