这是一个很好的问题,我自己偶然发现了几次。最简单的答案是,不能保证torch.argmax
(或torch.max(x, dim=k)
,在指定 dim 时也返回索引)将始终返回相同的索引。相反,它将返回任何有效索引到 argmax 值,可能是随机的。正如官方论坛中的这个线程所讨论的,这被认为是期望的行为。(我知道我不久前读过的另一个线程使这一点更加明确,但我再也找不到它了)。
话虽如此,由于这种行为对我的用例来说是不可接受的,所以我编写了以下函数来查找最左边和最右边的索引(请注意,这condition
是您传入的函数对象):
def __consistent_args(input, condition, indices):
assert len(input.shape) == 2, 'only works for batch x dim tensors along the dim axis'
mask = condition(input).float() * indices.unsqueeze(0).expand_as(input)
return torch.argmax(mask, dim=1)
def consistent_find_leftmost(input, condition):
indices = torch.arange(input.size(1), 0, -1, dtype=torch.float, device=input.device)
return __consistent_args(input, condition, indices)
def consistent_find_rightmost(input, condition):
indices = torch.arange(0, input.size(1), 1, dtype=torch.float, device=input.device)
return __consistent_args(input, condition, indices)
# one example:
consistent_find_leftmost(torch.arange(10).unsqueeze(0), lambda x: x>5)
# will return:
# tensor([6])
希望他们会有所帮助!(哦,如果您有更好的实现,请告诉我)