1

我正在开发 seq2seq 聊天机器人。我会问你,如何在 val_acc 计数时忽略聊天机器人响应中的 PAD 符号。

例如,我的模型生成响应:[I, am, reading, a, book, PAD, PAD, PAD, PAD, PAD]

但是,正确的反应应该是:[My, brother, is, playing, fotball,PAD, PAD, PAD, PAD, PAD].

在这种情况下,聊天机器人的反应完全错误,但由于填充符号,val_acc 为 50%。

我使用 Keras,编码器-解码器模型(https://blog.keras.io/a-ten-minute-introduction-to-sequence-to-sequence-learning-in-keras.html)和教师强制

我的代码在这里:

encoder_inputs = Input(shape=(sentenceLength,), name="Encoder_input")
encoder = LSTM(n_units, return_state=True, name='Encoder_lstm')
Shared_Embedding = Embedding(output_dim=embedding, input_dim=vocab_size, name="Embedding", mask_zero='True') 
word_embedding_context = Shared_Embedding(encoder_inputs)
encoder_outputs, state_h, state_c = encoder(word_embedding_context)
encoder_states = [state_h, state_c]

decoder_inputs = Input(shape=(None,), name="Decoder_input")
decoder_lstm = LSTM(n_units, return_sequences=True, return_state=True, name="Decoder_lstm")

word_embedding_answer = Shared_Embedding(decoder_inputs)
decoder_outputs, _, _ = decoder_lstm(word_embedding_answer, initial_state=encoder_states)
decoder_dense = Dense(vocab_size, activation='softmax', name="Dense_layer")
decoder_outputs = decoder_dense(decoder_outputs)
model = Model([encoder_inputs, decoder_inputs], decoder_outputs)

编码器输入是句子,其中每个单词都是整数,0 是填充:[1,2,5,4,3,0,0,0] -> 用户问题解码器输入也是句子,其中每个单词都是整数,0 是填充和100 是符号 GO:[100,8,4,2,0,0,0,0,0]] -> 聊天机器人响应移位了一个时间戳解码器输出是句子,其中单词是整数,这些整数是一个热编码的: [8,4,2,0,0,0,0,0, 0]] -> 聊天机器人响应(整数是一种热编码。)

问题是,val_acc 太高了,模型预测的句子也完全错误。我认为这是由于填充引起的。我的模型有问题吗?我应该在我的解码器中添加一些其他掩码吗?

这是我的图表: 在此处输入图像描述 在此处输入图像描述

4

1 回答 1

1

您是对的,这是因为该教程不使用Masking文档)来忽略这些填充值并显示相等输入输出长度的示例。在您的情况下,模型仍将输入输出 PAD,但掩码将忽略它们。例如,要屏蔽编码器:

# Define an input sequence and process it.
encoder_inputs = Input(shape=(None, num_encoder_tokens))
encoder_inputs = Masking()(encoder_inputs) # Assuming PAD is zeros
encoder = LSTM(latent_dim, return_state=True)
# Now the LSTM will ignore the PADs when encoding
# by skipping those timesteps that are masked
encoder_outputs, state_h, state_c = encoder(encoder_inputs)
于 2018-05-12T11:24:29.603 回答