一种选择
是使用torch.bmm()
which 正是这样做的(docs)。
它采用形状为 (b, n, m) 和 (b, m, p) 的张量,并返回形状为 (b, n, p) 的批量矩阵乘法。
(我假设您输入了 BXL 的结果,因为 1 XD 和 DXL 的矩阵乘法的形状为 1 XL 而不是 1 XD)。
在你的情况下:
import torch
B, L, D = 32, 10, 512
a = torch.randn(B, 1, D) #shape (B X 1 X D)
b = torch.randn(B, L, D) #shape (B X L X D)
b = b.transpose(1,2) #shape (B X D X L)
result = torch.bmm(a, b)
result = result.squeeze()
print(result.shape)
>>> torch.Size([32, 10])
或者
您可以使用torch.einsum()
,在我看来,它更紧凑但可读性较差:
import torch
B, L, D = 32, 10, 512
a = torch.randn(B, 1, D)
b = torch.randn(B, L, D)
result = torch.einsum('abc, adc->ad', a, b)
print(result.shape)
>>> torch.Size([32, 10])
最后的挤压是为了使您的结果为形状 (32, 10) 而不是形状 (32, 1, 10)。