至少在静态上,jnp.ndarray
是一个np.ndarray
修改非常少的子类
class ndarray(np.ndarray, metaclass=_ArrayMeta):
dtype: np.dtype
shape: Tuple[int, ...]
size: int
def __init__(shape, dtype=None, buffer=None, offset=0, strides=None,
order=None):
raise TypeError("jax.numpy.ndarray() should not be instantiated explicitly."
" Use jax.numpy.array, or jax.numpy.zeros instead.")
因此,它继承了np.ndarray
的方法类型签名。
我猜运行时行为是通过jnp.array
函数实现的。除非我错过了一些存根文件或类型欺骗,否则jnp.array
匹配的结果jnp.ndarray
仅仅是因为jnp.array
没有类型。你可以用
def foo(_: str) -> None:
pass
foo(jnp.array(0))
通过 mypy.
所以回答你的问题,我不认为你做错了什么。这是一个错误,它可能不是他们的意思,但它实际上并不是不正确的,因为np.ndarray
当你添加jnp.ndarray
s 时你确实得到了 a ,因为 ajnp.ndarray
是一个np.ndarray
。
至于为什么bool
s,那很可能是因为你jnp.array
的 s 缺少泛型参数,并且__add__
on的第一个有效重载np.ndarray
是
@overload
def __add__(self: NDArray[bool_], other: _ArrayLikeBool_co) -> NDArray[bool_]: ... # type: ignore[misc]
所以它只是默认为bool
.