0

RuntimeError: bool value of Tensor with more than one value is ambiguous在这段代码中如何避免?

import torch
import heapq

h = []
heapq.heappush(h, (1, torch.Tensor([[1,2]])))
heapq.heappush(h, (1, torch.Tensor([[3,4]])))

这是因为元组之间的比较在第一个元素相等时比较第二个元素

4

1 回答 1

0

有必要防止heapq在发现重复优先级时尝试比较元组的第二个元素,只需要<为我的元素重新定义运算符即可。

import torch
import heapq

class HeapItem:
    def __init__(self, p, t):
        self.p = p
        self.t = t

    def __lt__(self, other):
        return self.p < other.p

h = []
heapq.heappush(h, HeapItem(1, torch.Tensor([[1,2]])))
heapq.heappush(h, HeapItem(1, torch.Tensor([[3,4]])))
于 2019-06-03T23:20:05.827 回答