我有一个想要更新的大向量。我将通过向向量中的特定元素添加偏移量来更新它。我指定了一个要更新的索引向量(称为索引向量ix
),并且对于每个索引,我指定一个要添加到该元素的值(称为值向量vals
)。如果索引向量的所有条目都是唯一的,那么以下代码就足够了:
vec = torch.zeros(4, dtype=torch.float)
ix = torch.tensor([0,2], dtype=torch.long)
vals = torch.tensor([0.2, 0.5], dtype=torch.float)
vec[ix] += vals
但是,如果 中有重复的索引,这将不起作用ix
。对于重复索引的情况,一种简单的方法如下:
for i in range(len(ix)):
vec[ix[i]] += vals[i]
但这不能很好地扩展 - 很大时非常慢ix
。有没有更快的方法来做到这一点?vals
如果有一种快速的方法来汇总具有相同索引的所有条目ix
,那么解决方案应该很容易。
更新:
我找到了一种效果很好的解决方案,如下所述。我仍然希望获得更好的解决方案的反馈。
# get unique indices
ix_unique = torch.unique(ix)
# for each unique index, get sum of all vals with that index
vals_unique = torch.stack([
torch.sum(torch.where(ix==i, vals, torch.zeros_like(vals)))
for i in ix_unique
])
# update vec
vec[ix_unique] += vals_unique