
If you pass an array of general objects to numpy.unique, what is uniqueness of the result based on?

I tried:

import numpy as np

class A(object): #probably exists a nice mixin for this :P
    def __init__(self, a):
        self.a = a
    def __lt__(self, other):
        return self.a < other.a
    def __le__(self, other):
        return self.a <= other.a
    def __eq__(self, other):
        return self.a == other.a
    def __ge__(self, other):
        return self.a >= other.a
    def __gt__(self, other):
        return self.a > other.a
    def __ne__(self, other):
        return self.a != other.a
    def __repr__(self):
        return "A({})".format(self.a)
    def __str__(self):
        return self.__repr__()

np.unique(map(A, range(3)+range(3)))

which returns

array([A(0), A(0), A(1), A(1), A(2), A(2)], dtype=object)

but what I intended was:

array([A(0), A(1), A(2)], dtype=object)

1 Answer


Assuming the duplicate A(2) is a typo, I think you simply need to define __hash__ (see the docs):

import numpy as np
from functools import total_ordering

@total_ordering
class A(object):
    def __init__(self, a):
        self.a = a
    def __lt__(self, other):
        return self.a < other.a
    def __eq__(self, other):
        return self.a == other.a
    def __ne__(self, other):
        return self.a != other.a
    def __hash__(self):
        return hash(self.a)
    def __repr__(self):
        return "A({})".format(self.a)
    def __str__(self):
        return repr(self)

which produces

>>> map(A, range(3)+range(3))
[A(0), A(1), A(2), A(0), A(1), A(2)]
>>> set(map(A, range(3)+range(3)))
set([A(0), A(1), A(2)])
>>> np.unique(map(A, range(3)+range(3)))
array([A(0), A(1), A(2)], dtype=object)
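
The session above is Python 2: there map returns a list and range(3)+range(3) concatenates two lists. As a rough sketch of the same idea under Python 3 (an assumption on my part, not something tested against the NumPy version used in 2013), the input has to be built as an explicit list, and __hash__ has to be written out because Python 3 sets it to None as soon as a class defines __eq__. The names items, arr and unique_items are only for illustration.

```python
from functools import total_ordering

import numpy as np


@total_ordering
class A:
    def __init__(self, a):
        self.a = a

    def __lt__(self, other):
        return self.a < other.a

    def __eq__(self, other):
        return self.a == other.a

    # Python 3 sets __hash__ to None once __eq__ is defined, so restore it
    # explicitly to keep instances usable in sets and as dict keys.
    def __hash__(self):
        return hash(self.a)

    def __repr__(self):
        return "A({})".format(self.a)


# range objects cannot be added together in Python 3, so convert them to
# lists first and build the A instances from the combined list.
items = [A(i) for i in list(range(3)) + list(range(3))]

# Passing dtype=object makes it explicit that this is an object array.
arr = np.array(items, dtype=object)

unique_items = np.unique(arr)
print(unique_items)
print(len(set(items)))
```

Python 3 also derives __ne__ from __eq__ by default, so the explicit __ne__ from the Python 2 version is not needed in this sketch.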

As you guessed, I used total_ordering in there to reduce the proliferation of methods. :^)

[Edited after posting to correct the missing __ne__.]
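
To make the total_ordering point concrete, here is a minimal Python 3 sketch. The class name Item is hypothetical and only used for this illustration. It writes __eq__ and __lt__ by hand and then checks which comparison methods the decorator added to the class.

```python
from functools import total_ordering


@total_ordering
class Item:  # hypothetical name, used only for this illustration
    def __init__(self, a):
        self.a = a

    def __eq__(self, other):
        return self.a == other.a

    def __lt__(self, other):
        return self.a < other.a


# Only __eq__ and __lt__ are written by hand; the decorator is expected to
# attach the remaining ordering methods directly to the class.
derived = [name for name in ("__le__", "__gt__", "__ge__") if name in vars(Item)]
print(derived)

# The derived comparisons behave consistently with __lt__ and __eq__.
print(Item(1) <= Item(2), Item(2) > Item(1), Item(2) >= Item(2))
```

Note that total_ordering only fills in the ordering methods; it does not supply __hash__ (nor, in Python 2, __ne__).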

answered 2013-05-05T22:14:22.717