2

如何以与 JAX 兼容的方式(例如,使用jax.numpy)实现以下内容?

def actions(state: tuple[int, ...]) -> list[tuple[int, ...]]:
    l = []
    iterables = [range(1, i+1) for i in state]
    ns = list(range(len(iterables)))
    for i, iterable in enumerate(iterables):
        for value in iterable:
            action = tuple(value if n == i else 0 for n in ns)
            l.append(action)
    return l

>>> state = (3, 1, 2)
>>> actions(state)
[(1, 0, 0), (2, 0, 0), (3, 0, 0), (0, 1, 0), (0, 0, 1), (0, 0, 2)]
4

1 回答 1

1

与 numpy 一样,Jax 无法有效地对 Python 容器类型(如列表和元组)进行操作,因此实际上没有任何与 JAX 兼容的方式来创建具有您上面指定的确切签名的函数。

但是如果你对返回值是一个二维数组没意见,你可以做这样的事情,基于jnp.vstack

from typing import Tuple
import jax.numpy as jnp
from jax import jit, partial

@partial(jit, static_argnums=0)
def actions(state: Tuple[int, ...]) -> jnp.ndarray:
  return jnp.vstack([
    jnp.zeros((val, len(state)), int).at[:, i].set(jnp.arange(1, val + 1))
    for i, val in enumerate(state)])
>>> state = (3, 1, 2)
>>> actions(state)
DeviceArray([[1, 0, 0],
             [2, 0, 0],
             [3, 0, 0],
             [0, 1, 0],
             [0, 0, 1],
             [0, 0, 2]], dtype=int32)

请注意,由于输出数组的大小取决于 的内容statestate必须是静态量,因此元组是输入的不错选择。

于 2021-04-27T22:19:15.757 回答