如果我创建一个包含 Numpy ndarray 的 Python 数据类,我就不能再使用自动生成__eq__
的了。
import numpy as np
@dataclass
class Instr:
foo: np.ndarray
bar: np.ndarray
arr = np.array([1])
arr2 = np.array([1, 2])
print(Instr(arr, arr) == Instr(arr2, arr2))
ValueError:具有多个元素的数组的真值不明确。使用 a.any() 或 a.all()
这是因为ndarray.__eq__
有时通过比较2 中较长的一个来返回ndarray
真值的a不同的价值观什么的。a[0]
b[0]
如何安全地比较@dataclass
持有 Numpy 数组的 es?
@dataclass
的实现__eq__
是使用生成的eval()
。它的源代码从堆栈跟踪中丢失,无法使用 查看inspect
,但它实际上使用的是元组比较,它调用 bool(foo)。
import dis
dis.dis(Instr.__eq__)
摘抄:
3 12 LOAD_FAST 0 (self) 14 LOAD_ATTR 1 (foo) 16 LOAD_FAST 0 (self) 18 LOAD_ATTR 2 (bar) 20 BUILD_TUPLE 2 22 LOAD_FAST 1 (other) 24 LOAD_ATTR 1 (foo) 26 LOAD_FAST 1 (other) 28 LOAD_ATTR 2 (bar) 30 BUILD_TUPLE 2 32 COMPARE_OP 2 (==) 34 RETURN_VALUE