-1

我有一个用于字符级英语拼写校正的编码器-解码器模型,它是非常基本的东西,有两个 LSTM 编码器和另一个 LSTM 解码器。

但是,到目前为止,我一直在预先填充编码器输入序列,如下所示:

abc  -> -abc
defg -> defg
ad   -> --ad

接下来我一直将数据分成几个具有相同解码器输入长度的组,例如

train_data = {'15': [...], '16': [...], ...}

其中关键是解码器输入数据的长度,我一直在为循环中的每个长度训练一次模型。

但是,必须有更好的方法来做到这一点,例如在 EOS 之后或 SOS 字符之前进行填充等。但如果是这种情况,我将如何更改损失函数,以便该填充不计入损失?

4

1 回答 1

0

进行填充的标准方法是将其放在序列结束标记之后,但填充的位置应该很重要。

如何不将填充位置包含在损失中的技巧是在减少损失之前将它们屏蔽掉。假设PAD_ID变量包含您用于填充的符号的索引:

def custom_loss(y_true, y_pred):
    mask = 1 - K.cast(K.equal(y_true, PAD_ID), K.floatx())
    loss = K.categorical_crossentropy(y_true, y_pred) * mask    
    return K.sum(loss) / K.sum(mask)
于 2020-03-03T09:19:13.087 回答