3

我的问题很简单:

>>> isinstance(x, jax.numpy.ndarray)
True
>>> issubclass(jax.numpy.ndarray, numpy.ndarray)
True
>>> isinstance(x, numpy.ndarray)
False

?

现在我会闲逛,所以 SE 会接受我的合理问题。

4

1 回答 1

4

出现这种情况的原因是因为jax.numpy.ndarray使用元类覆盖了实例检查:

class _ArrayMeta(type(np.ndarray)):  # type: ignore
  """Metaclass for overriding ndarray isinstance checks."""

  def __instancecheck__(self, instance):
    try:
      return isinstance(instance.aval, _arraylike_types)
    except AttributeError:
      return isinstance(instance, _arraylike_types)

class ndarray(np.ndarray, metaclass=_ArrayMeta):
  dtype: np.dtype
  shape: Tuple[int, ...]
  size: int

  def __init__(shape, dtype=None, buffer=None, offset=0, strides=None,
               order=None):
    raise TypeError("jax.numpy.ndarray() should not be instantiated explicitly."
                    " Use jax.numpy.array, or jax.numpy.zeros instead.")

查看源代码

您的代码返回它所做的事情的原因是因为您有一个x值,它不是 的实例numpy.ndarray,但此__instancecheck__方法返回 true。

为什么在 JAX 中有这种诡计?好吧,出于 JIT 编译、自动微分和其他转换的目的,JAX 使用称为跟踪器的替代对象,这些对象看起来和行为都像一个数组,尽管实际上并不是一个数组。这种对实例检查的覆盖是 JAX 用来使这种跟踪工作的技巧之一。

于 2020-11-02T23:55:49.363 回答