4

假设我们有一个 3D PyTorch 张量,其中第一个维度表示batch_size,如下所示:

import torch
import torch.nn as nn
x = torch.randn(32, 100, 25)

也就是说,对于每个ix[i]是一组 100 个 25 维向量。我想为每个批次项目计算这些向量的相似度(例如,余弦相似度——但通常是任何这样的成对距离/相似度矩阵)。

也就是说,对于每个x[i]我需要计算一个[100, 100]矩阵,该矩阵将包含上述向量的成对相似性。更具体地说,该矩阵的第 (i,j) 元素应包含 (100x25) 的第 i 行和第 j 行之间的相似性(或距离)x[t],对于所有t=1, ..., batch_size

如果我使用torch.nn.CosineSimilarity(),无论dim我使用什么,结果都是[100, 25](dim=0)[32, 25]( dim=1) ,我需要一个大小的张量[32, 100, 100]。我希望torch.nn.CosineSimilarity()以这种方式工作(因为,至少对我来说,它看起来更直观) ,但事实并非如此。

可以使用下面的方法来完成吗?

torch.matmul(x, x.permute(0, 2, 1))

我想这可以给出一个距离矩阵,但是如果我需要一个任意的成对操作怎么办?我应该使用上述方法构建此操作吗?

或者也许我应该x以某种方式重复,以便我可以使用内置的torch.nn.CosineSimilarity()

谢谢你。

4

4 回答 4

6

该文档暗示输入的形状cosine_similarity必须相等,但事实并非如此。在 PyTorch 内部通过 广播torch.mul,插入带有切片(或torch.unsqueeze)的维度将为您提供所需的结果。由于上下三角形的重复计算和内存,这不是最佳的,但它很简单:

import torch
from torch.nn import functional as F
from scipy.spatial import distance

# compute once in pytorch
x = torch.randn(32, 100, 25)
y = F.cosine_similarity(x[..., None, :, :], x[..., :, None, :], dim=-1)

assert y.shape == torch.Size([32, 100, 100])

# test against scipy by iterating over each batch element
z = []
for i in range(x.shape[0]):
    slice = x[i, ...].numpy()
    z.append(torch.tensor(distance.cdist(slice, slice, metric='cosine'), dtype=torch.float32))

# convert similarity to distance and ensure they're reasonably close
assert torch.allclose(torch.stack(z), 1.0-y)

于 2020-12-29T23:53:47.420 回答
3

如果您仔细阅读文档,nn.CosineSimilaritynn.PairwiseDistance会发现它们不会计算所有成对的相似性/距离(如您所要求的),而是期望两个具有相同形状的输入,并计算所有对应点之间的相似性/距离只要。
也就是说,如果您有两组 100 个 32 维向量,这些函数将计算的是第一个集合中的第一个向量与第二个集合i中对应的第一个向量之间的相似度/距离i,结果只有 100 个相似度/距离价值观。

如果要计算所有成对距离,则需要手动计算它们。
使用torch.matmul似乎是朝着正确方向迈出的一步。

如果您正在寻找一种计算 L2 距离的有效方法,您可能会发现此答案中的方法很有用。

于 2020-03-01T10:49:04.007 回答
1

对于欧几里德距离/相似度(或更一般地任何 p 范数距离)的情况,该问题的部分答案,但不适用于余弦相似度:

使用torch.cdist,它“计算两个行向量集合的每对之间的 p-norm 距离”。

于 2021-04-19T15:54:00.690 回答
0

这应该这样做:

import torch.nn.functional as F
x = torch.randn(32, 100, 25)

# cosine similarity: normalize and multiply
cos = lambda m: F.normalize(m) @ F.normalize(m).t()
torch.stack([cos(m) for m in x])  # [32, 100, 100]

注意:这是句子转换器中的余弦相似度实现

于 2021-10-14T21:29:11.853 回答