我正在关注Pytorch seq2seq 教程,它torch.bmm
的使用方法如下:
attn_applied = torch.bmm(attn_weights.unsqueeze(0),
encoder_outputs.unsqueeze(0))
我理解为什么我们需要将注意力权重和编码器输出相乘。
我不太明白的是我们bmm
在这里需要方法的原因。
torch.bmm
文件说
执行存储在 batch1 和 batch2 中的矩阵的批处理矩阵乘积。
batch1 和 batch2 必须是 3-D 张量,每个张量都包含相同数量的矩阵。
如果batch1是(b×n×m)张量,batch2是(b×m×p)张量,out是(b×n×p)张量。