-1

我有一个形状为 n×2 的 numpy 数组,一堆长度为 2 的元组,我想将其转移到 SortedList。所以目标是创建一个长度为 2 的整数元组的 SortedList。

问题是 SortedList 的构造函数检查每个条目的真值。这适用于一维数组:

In [1]: import numpy as np
In [2]: from sortedcontainers import SortedList
In [3]: a = np.array([1,2,3,4])
In [4]: SortedList(a)
Out[4]: SortedList([1, 2, 3, 4], load=1000)

但是对于二维来说,当每个条目都是数组时,没有明确的真值,SortedList 是不合作的:

In [5]: a.resize(2,2)
In [6]: a
Out[6]: 
array([[1, 2],
       [3, 4]])

In [7]: SortedList(a)
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-7-7a4b2693bb52> in <module>()
----> 1 SortedList(a)

/home/me/miniconda3/envs/env/lib/python3.6/site-packages/sortedcontainers/sortedlist.py in __init__(self, iterable, load)
     81 
     82         if iterable is not None:
---> 83             self._update(iterable)
     84 
     85     def __new__(cls, iterable=None, key=None, load=1000):

/home/me/miniconda3/envs/env/lib/python3.6/site-packages/sortedcontainers/sortedlist.py in update(self, iterable)
    176         _lists = self._lists
    177         _maxes = self._maxes
--> 178         values = sorted(iterable)
    179 
    180         if _maxes:

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

我目前的解决方法是手动将每一行转换为一个元组:

sl = SortedList()
for t in np_array:
    x, y = t
    sl.add((x,y))

但是,该解决方案提供了一些改进空间。有谁知道如何在不将所有数组显式解包成元组的情况下解决这个问题?

4

1 回答 1

1

问题不在于正在检查数组的真值,而在于正在对它们进行比较以便对它们进行排序。如果对数组使用比较运算符,则会返回数组:

>>> import numpy as np
>>> np.array([1, 4]) < np.array([2, 3])
array([ True, False], dtype=bool)

这个得到的布尔数组实际上是其真值正在被检查的数组sorted

另一方面,元组(或列表)的相同操作将逐个元素进行比较并返回单个布尔值:

>>> (1, 4) < (2, 3)
True
>>> (1, 4) < (1, 3)
False

因此,当SortedList尝试在数组sorted序列上使用时numpy,它无法进行比较,因为它需要比较运算符返回单个布尔值。

一种抽象的方法是创建一个新的数组类,该类实现比较运算符,如__eq____lt____gt__等,以重现元组的排序行为。具有讽刺意味的是,最简单的方法是将底层数组转换为元组,例如:

class SortableArray(object):

    def __init__(self, seq):
        self._values = np.array(seq)

    def __eq__(self, other):
        return tuple(self._values) == tuple(other._values)
        # or:
        # return np.all(self._values == other._values)

    def __lt__(self, other):
        return tuple(self._values) < tuple(other._values)

    def __gt__(self, other):
        return tuple(self._values) > tuple(other._values)

    def __le__(self, other):
        return tuple(self._values) <= tuple(other._values)

    def __ge__(self, other):
        return tuple(self._values) >= tuple(other._values)

    def __str__(self):
        return str(self._values)

    def __repr__(self):
        return repr(self._values)

使用此实现,您现在可以对SortableArray对象列表进行排序:

In [4]: ar1 = SortableArray([1, 3])

In [5]: ar2 = SortableArray([1, 4])

In [6]: ar3 = SortableArray([1, 3])

In [7]: ar4 = SortableArray([4, 5])

In [8]: ar5 = SortableArray([0, 3])

In [9]: lst1 = [ar1, ar2, ar3, ar4, ar5]

In [10]: lst1
Out[10]: [array([1, 3]), array([1, 4]), array([1, 3]), array([4, 5]), array([0, 3])]

In [11]: sorted(lst1)
Out[11]: [array([0, 3]), array([1, 3]), array([1, 3]), array([1, 4]), array([4, 5])]

对于您需要的东西,这可能有点矫枉过正,但这是一种方法。在任何情况下,您都不会sorted在比较时不返回单个布尔值的对象序列上使用。

如果您所追求的只是避免 for 循环,则可以将其替换为列表理解(即SortedList([tuple(row) for row in np_array]))。

于 2017-12-01T20:41:40.550 回答