3

我正在使用 JAX,我想执行类似的操作

@jax.jit
def fun(x, index):
    x[:index] = other_fun(x[:index])
    return x

这不能在 下执行jit。有没有办法用jax.opsor做到这一点jax.lax?我曾想过使用jax.ops.index_update(x, idx, y),但我找不到一种计算方法y而不会再次遇到同样的问题。

4

2 回答 2

4

如果您的索引是静态的, @rvinas using的先前答案dynamic_slice效果很好,但您也可以使用动态索引来完成此操作jnp.where。例如:

import jax
import jax.numpy as jnp

def other_fun(x):
    return x + 1

@jax.jit
def fun(x, index):
  mask = jnp.arange(x.shape[0]) < index
  return jnp.where(mask, other_fun(x), x)

x = jnp.arange(5)
print(fun(x, 3))
# [1 2 3 3 4]
于 2021-07-22T14:51:59.750 回答
1

您的实施中似乎存在两个问题。首先,切片生成动态形状的数组(在 jitted 代码中不允许)。其次,与 numpy 数组不同,JAX 数组是不可变的(即数组的内容不能更改)。

static_argnums您可以通过结合和来克服这两个问题jax.lax.dynamic_update_slice。这是一个例子:

def other_fun(x):
    return x + 1

@jax.partial(jax.jit, static_argnums=(1,))
def fun(x, index):
    update = other_fun(x[:index])
    return jax.lax.dynamic_update_slice(x, update, (0,))

x = jnp.arange(5)
print(fun(x, 3))  # prints [1 2 3 3 4]

本质上,上面的示例static_argnums用于指示该函数应针对不同的index值重新编译,并创建一个具有更新值jax.lax.dynamic_update_slice的副本。x:len(update)

于 2021-07-17T18:30:42.557 回答