1

我有一个 size 的索引张量(2, 3)

>>> index = torch.empty(6).random_(0,8).view(2,3)
tensor([[6., 3., 2.],
        [3., 4., 7.]])

和一个 size 的值张量(2, 8)

>>> value = torch.zeros(2,8)
tensor([[0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.]])

我想通过索引设置元素value1**dim=-1输出应该是这样的:

>>> output
tensor([[0., 0., 1., 1., 0., 0., 1., 0.],
        [0., 0., 0., 1., 1., 0., 0., 1.]])

我试过value[range(2), index] = 1了,但它触发了一个错误。我也试过torch.index_fill,但它不接受批量索引。torch.scatter需要创建一个大小为 的额外张量2*81这会消耗不必要的内存和时间。

4

1 回答 1

2

您实际上可以torch.Tensor.scatter_通过设置value( int ) 选项而不是src选项 ( Tensor ) 来使用。

>>> value.scatter_(dim=-1, index=index.long(), value=1)

>>> value
tensor([[0., 0., 1., 1., 0., 0., 1., 0.],
        [0., 0., 0., 1., 1., 0., 0., 1.]])

确保它indexint64类型。

于 2021-08-29T10:17:26.023 回答