27

我使用 Python 的unittest模块,想检查两个复杂的数据结构是否相等。对象可以是具有各种值的字典列表:数字、字符串、Python 容器(列表/元组/字典)和numpy数组。后者是问这个问题的原因,因为我不能只是做

self.assertEqual(big_struct1, big_struct2)

因为它产生一个

ValueError: The truth value of an array with more than one element is ambiguous.
Use a.any() or a.all()

我想我需要为此编写自己的平等测试。它应该适用于任意结构。我目前的想法是一个递归函数:

  • 尝试将当前“节点”arg1与 的相应节点进行直接比较arg2
  • 如果没有引发异常,则继续(“终端”节点/叶子也在此处处理);
  • 如果ValueError被捕获,则继续深入,直到找到一个numpy.array;
  • 比较数组(例如像这样)。

跟踪两个结构的“对应”节点似乎有点问题,但也许zip我在这里只需要。

问题是:这种方法是否有更好(更简单)的替代方案?也许numpy为此提供一些工具?如果没有建议替代方案,我将实施这个想法(除非我有更好的想法)并作为答案发布。

PS我有一种模糊的感觉,我可能已经看到了一个解决这个问题的问题,但我现在找不到它。

PPS 另一种方法是遍历结构并将所有numpy.arrays 转换为列表的函数,但这更容易实现吗?对我来说似乎一样。


编辑:子类numpy.ndarray化听起来很有希望,但显然我没有将比较的双方硬编码到测试中。但是,其中一个确实是硬编码的,所以我可以:

  • numpy.array用;的自定义子类填充它
  • 更改isinstance(other, SaneEqualityArray)jterraceisinstance(other, np.ndarray)答案
  • 在比较中始终将其用作 LHS。

我在这方面的问题是:

  1. 它会起作用吗(我的意思是,这对我来说听起来不错,但可能无法正确处理一些棘手的边缘情况)?如我所料,我的自定义对象在递归相等检查中总是以 LHS 结尾吗?
  2. 同样,有没有更好的方法(假设我得到了至少一个带有真实numpy数组的结构)。

编辑2:我试过了,这个答案中显示了(看似)工作的实现。

4

7 回答 7

13

本来可以评论的,但是太长了...

有趣的事实是,你不能用它==来测试数组是否相同,我建议你np.testing.assert_array_equal改用它。

  1. 检查 dtype、shape 等,
  2. 这不会因为(float('nan') == float('nan')) == False(正常的 python 序列有时会==以一种更有趣的方式忽略它)而失败,因为它使用which 进行(对于 NaN 不正确的)快速检查(用于测试当然是完美的).. .PyObject_RichCompareBoolis
  3. 还有一个原因是,如果您进行实际计算并且您通常想要几乎assert_allclose相同的值,浮点相等可能会变得非常棘手,因为这些值可能取决于硬件或可能是随机的,具体取决于您对它们的处理方式。

如果您想要这种疯狂嵌套的东西,我几乎建议您尝试使用 pickle 对其进行序列化,但这过于严格(然后第 3 点当然完全被破坏了),例如,您的数组的内存布局无关紧要,但对它很重要序列化。

于 2013-01-10T02:02:22.667 回答
9

assertEqual函数将调用__eq__对象的方法,该方法应针对复杂数据类型进行递归。例外是 numpy,它没有健全的__eq__方法。使用此问题中的 numpy 子类,您可以恢复平等行为的理智:

import copy
import numpy
import unittest

class SaneEqualityArray(numpy.ndarray):
    def __eq__(self, other):
        return (isinstance(other, SaneEqualityArray) and
                self.shape == other.shape and
                numpy.ndarray.__eq__(self, other).all())

class TestAsserts(unittest.TestCase):

    def testAssert(self):
        tests = [
            [1, 2],
            {'foo': 2},
            [2, 'foo', {'d': 4}],
            SaneEqualityArray([1, 2]),
            {'foo': {'hey': SaneEqualityArray([2, 3])}},
            [{'foo': SaneEqualityArray([3, 4]), 'd': {'doo': 3}},
             SaneEqualityArray([5, 6]), 34]
        ]
        for t in tests:
            self.assertEqual(t, copy.deepcopy(t))

if __name__ == '__main__':
    unittest.main()

该测试通过。

于 2013-01-10T01:07:28.040 回答
7

所以jterrace说明的想法似乎对我有用,只需稍作修改:

class SaneEqualityArray(np.ndarray):
    def __eq__(self, other):
        return (isinstance(other, np.ndarray) and self.shape == other.shape and 
            np.allclose(self, other))

就像我说的,包含这些对象的容器应该在相等检查的左侧。我从现有的 s 创建SaneEqualityArray对象,numpy.ndarray如下所示:

SaneEqualityArray(my_array.shape, my_array.dtype, my_array)

按照ndarray构造函数签名:

ndarray(shape, dtype=float, buffer=None, offset=0,
        strides=None, order=None)

此类在测试套件中定义,仅用于测试目的。相等检查的 RHS 是被测试函数返回的一个实际对象,包含真实numpy.ndarray对象。

PS 感谢迄今为止发布的两个答案的作者,他们都非常有帮助。如果有人发现这种方法有任何问题,我会很感激您的反馈。

于 2013-01-11T11:09:38.110 回答
2

我将定义我自己的 assertNumpyArraysEqual() 方法,该方法明确地进行您想要使用的比较。这样,您的生产代码不会改变,但您仍然可以在单元测试中做出合理的断言。确保在包含的模块中定义它,__unittest = True以便它不会包含在堆栈跟踪中:

import numpy
__unittest = True

def assertNumpyArraysEqual(self, other):
    if self.shape != other.shape:
        raise AssertionError("Shapes don't match")
    if not numpy.allclose(self, other)
        raise AssertionError("Elements don't match!")
于 2013-03-14T01:13:39.003 回答
1

检查numpy.testing.assert_almost_equal哪个“如果两个项目不等于所需的精度,则引发 AssertionError”,例如:

 import numpy.testing as npt
 npt.assert_almost_equal(np.array([1.0,2.3333333333333]),
                         np.array([1.0,2.33333334]), decimal=9)
于 2016-08-19T10:53:07.210 回答
1

我遇到了同样的问题,并开发了一个函数来基于为对象创建一个固定的哈希来比较相等性。这具有额外的优势,您可以通过将对象的哈希值与代码中支持的固定值进行比较来测试对象是否符合预期。

代码(一个独立的 python 文件,在这里)。有两个函数:fixed_hash_eq,它可以解决您的问题,以及compute_fixed_hash,它从结构中生成哈希。 测试在这里

这是一个测试:

obj1 = [1, 'd', {'a': 4, 'b': np.arange(10)}, (7, [1, 2, 3, 4, 5])]
obj2 = [1, 'd', {'a': 4, 'b': np.arange(10)}, (7, [1, 2, 3, 4, 5])]
obj3 = [1, 'd', {'a': 4, 'b': np.arange(10)}, (7, [1, 2, 3, 4, 5])]
obj3[2]['b'][4] = 0
assert fixed_hash_eq(obj1, obj2)
assert not fixed_hash_eq(obj1, obj3)
于 2017-11-01T08:23:57.200 回答
0

在@dbw 的基础上(感谢),插入到测试用例子类中的以下方法对我来说效果很好:

 def assertNumpyArraysEqual(self,this,that,msg=''):
    '''
    modified from http://stackoverflow.com/a/15399475/5459638
    '''
    if this.shape != that.shape:
        raise AssertionError("Shapes don't match")
    if not np.allclose(this,that):
        raise AssertionError("Elements don't match!")

self.assertNumpyArraysEqual(this,that)在我的测试用例方法中调用了它,并且像一个魅力一样工作。

于 2016-03-30T21:55:02.120 回答