1

考虑以下文件:

import jax.numpy as jnp

def test(a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray:
    return a + b

运行mypy mypytest.py返回以下错误:

mypytest.py:4: error: Incompatible return value type (got "numpy.ndarray[Any, dtype[bool_]]", expected "jax._src.numpy.lax_numpy.ndarray")

由于某种原因,它认为添加两个jax.numpy.ndarrays 会返回一个 NumPy 数组bools。难道我做错了什么?或者这是 MyPy 或 Jax 的类型注释中的错误?

4

2 回答 2

3

至少在静态上,jnp.ndarray是一个np.ndarray修改非常少的子类

class ndarray(np.ndarray, metaclass=_ArrayMeta):
  dtype: np.dtype
  shape: Tuple[int, ...]
  size: int

  def __init__(shape, dtype=None, buffer=None, offset=0, strides=None,
               order=None):
    raise TypeError("jax.numpy.ndarray() should not be instantiated explicitly."
                    " Use jax.numpy.array, or jax.numpy.zeros instead.")

因此,它继承了np.ndarray的方法类型签名。

我猜运行时行为是通过jnp.array函数实现的。除非我错过了一些存根文件或类型欺骗,否则jnp.array匹配的结果jnp.ndarray仅仅是因为jnp.array没有类型。你可以用

def foo(_: str) -> None:
   pass

foo(jnp.array(0))

通过 mypy.

所以回答你的问题,我不认为你做错了什么。这是一个错误,它可能不是他们的意思,但它实际上并不是不正确的,因为np.ndarray当你添加jnp.ndarrays 时你确实得到了 a ,因为 ajnp.ndarray是一个np.ndarray

至于为什么bools,那很可能是因为你jnp.array的 s 缺少泛型参数,并且__add__on的第一个有效重载np.ndarray

    @overload
    def __add__(self: NDArray[bool_], other: _ArrayLikeBool_co) -> NDArray[bool_]: ...  # type: ignore[misc]

所以它只是默认为bool.

于 2021-08-22T20:36:32.233 回答
1

一般来说,JAX 与 mypy 的兼容性很差,因为很难用 JAX 的转换模型来满足 mypy 的约束,它经常调用具有转换特定跟踪器值的函数,这些跟踪器值充当数组的替身(请参阅如何在 JAX 中思考: JIT Mechanics简要讨论了这种机制)。

将跟踪器类型用作数组的替代品意味着 mypy 将在转换严格类型的 JAX 函数时引发错误,因此在整个 JAX 代码库中,我们倾向于将其别名ArrayAny,并将其用作 JAX 函数的返回类型注释返回数组。

对此进行改进会很好,因为Any返回类型对于有效的类型检查不是很有用,但这只是使 mypy 与 JAX 良好配合的众多挑战中的第一个。如果你想阅读过去几年围绕这个问题的一些讨论,我会从这里开始:https ://github.com/google/jax/issues/943

同时,我的建议是Any用作 JAX 数组的类型注释。

于 2021-08-22T21:26:37.193 回答