11

我正在关注Pytorch seq2seq 教程,它torch.bmm的使用方法如下:

attn_applied = torch.bmm(attn_weights.unsqueeze(0),
                         encoder_outputs.unsqueeze(0))

我理解为什么我们需要将注意力权重和编码器输出相乘。

我不太明白的是我们bmm在这里需要方法的原因。 torch.bmm文件说

执行存储在 batch1 和 batch2 中的矩阵的批处理矩阵乘积。

batch1 和 batch2 必须是 3-D 张量,每个张量都包含相同数量的矩阵。

如果batch1是(b×n×m)张量,batch2是(b×m×p)张量,out是(b×n×p)张量。

在此处输入图像描述

4

3 回答 3

13

在 seq2seq 模型中,编码器将输入序列编码为小批量。例如,输入是B x S x d其中 B 是批量大小,S 是最大序列长度,d 是词嵌入维度。然后编码器的输出是B x S x h其中 h 是编码器的隐藏状态大小(它是一个 RNN)。

现在,在解码(在训练期间) 时,一次给定一个输入序列,所以输入是B x 1 x d,解码器产生一个形状的张量B x 1 x h。现在要计算上下文向量,我们需要将此解码器隐藏状态与编码器的编码状态进行比较。

所以,考虑你有两个形状张量T1 = B x S x hT2 = B x 1 x h。因此,如果您可以按如下方式进行批量矩阵乘法。

out = torch.bmm(T1, T2.transpose(1, 2))

B x S x h本质上,您是将一个形状张量与一个形状张量相乘B x h x 1,这将导致B x S x 1每个批次的注意力权重。

这里,注意力权重B x S x 1表示解码器当前隐藏状态和编码器所有隐藏状态之间的相似度得分。现在,您可以通过先转置将注意力权重与编码器的隐藏状态相乘,B x S x h这将产生一个 shape 的张量B x h x 1。如果你在 dim=2 处执行挤压,你会得到一个形状张量,B x h它是你的上下文向量。

这个上下文向量 ( B x h) 通常连接到解码器的隐藏状态 ( B x 1 x h,squeeze dim=1) 以预测下一个标记。

于 2018-06-13T04:30:27.120 回答
3

上图中描述的操作发生在DecoderSeq2Seq 模型一侧。这意味着编码器的输出 已经是批量的(带有小批量大小的样本)。因此,attn_weights张量也应该处于批处理模式。

因此,本质上,张量的第一个维度(NumPyzero术语中的 th 轴)是mini-batch size 的样本数。因此,我们需要这两个张量。attn_weightsencoder_outputstorch.bmm

于 2018-06-12T23:19:01.617 回答
2

虽然@wasiahmad 关于 seq2seq 的一般实现是正确的,但在提到的教程中没有批次(B=1),bmm只是过度设计,可以安全地替换matmul为完全相同的模型质量和性能。自己看,替换这个:

        attn_applied = torch.bmm(attn_weights.unsqueeze(0),
                                 encoder_outputs.unsqueeze(0))
        output = torch.cat((embedded[0], attn_applied[0]), 1)

有了这个:

        attn_applied = torch.matmul(attn_weights,
                                 encoder_outputs)
        output = torch.cat((embedded[0], attn_applied), 1)

并运行笔记本。


另外请注意,虽然@wasiahmad 将编码器输入讨论为B x S x d,但在 pytorch 1.7.0 中,作为编码器主引擎的GRU(seq_len, batch, input_size)需要默认的输入格式。如果您想使用@wasiahmad 格式,请传递batch_first = True标志。

于 2020-12-05T21:41:30.677 回答