比较元组时,比较它们的第一个元素,然后使用它们的第二个元素、它们的元素等来打破任何关系。例如,(2, "a") < (2, "b")
将评估为True
。
在这里,您将(node.val, node)
元组插入到堆中,尝试比较它们。如果节点值存在平局,它会移动到元组的第二个元素——节点本身。这些只是ListNode
实例。Python 不知道如何比较两个ListNodes
,因此出现错误。
要启用比较ListNodes
,您需要实现丰富的比较方法。一种快速的方法是简单地实现ListNode.__lt__
然后使用functools.total_ordering
装饰器:
import heapq
from functools import total_ordering
@total_ordering
class ListNode:
def __init__(self, value: float, label: str) -> None:
self.value = value
self.label = label
def __lt__(self, other: 'ListNode'):
return self.value <= other.value
def __str__(self):
return f"ListNode(label={self.label}, value={self.value})"
nodes = []
a = ListNode(5, "A")
b = ListNode(3, "B")
c = ListNode(5, "C")
heapq.heappush(nodes, a)
heapq.heappush(nodes, b)
heapq.heappush(nodes, c)
while nodes:
x = heapq.heappop(nodes)
print(x)
这里我们说比较两个ListNode
对象与比较它们的值是一样的。定义了比较后,您甚至根本不需要插入元组。您可以直接插入ListNode
对象,并依赖比较方法。