def _resolve_decode_setting(value, name):
    """Return *value*, or the module-level global *name* when *value* is None."""
    if value is not None:
        return value
    try:
        return globals()[name]
    except KeyError:
        raise NameError(
            f"name {name!r} is not defined; pass it as a keyword argument"
        ) from None


def decode_sequence_seq2seq_model_with_just_lstm(
    input_sequence,
    encoder_model,
    decoder_model,
    *,
    target_word_index=None,
    reverse_target_word_index=None,
    start_token=None,
    end_token=None,
    max_summary_len=None,
):
    """Greedily decode one input sequence with an LSTM encoder/decoder pair.

    The encoder runs once on ``input_sequence``. The decoder is then fed one
    token at a time, starting with ``start_token``. At every step the
    highest-scoring id (``argmax``) is appended and fed back as the next
    decoder input, together with the decoder's updated hidden/cell states.

    Args:
        input_sequence: Passed unchanged to ``encoder_model.predict``.
        encoder_model: Object whose ``predict`` returns
            ``(encoder_outputs, state_h, state_c)``.
        decoder_model: Object whose ``predict`` accepts
            ``[target_seq, encoder_outputs, state_h, state_c]`` and returns
            ``(output_tokens, state_h, state_c)``. ``output_tokens`` is indexed
            as ``[0, -1, :]`` (presumably batch, time, vocab -- confirm
            against the model definition).
        target_word_index: Mapping token -> id. Defaults to the module-level
            global of the same name.
        reverse_target_word_index: Mapping id -> token. Defaults to the
            module-level global of the same name.
        start_token: Token whose id seeds the decoder. Defaults to the
            module-level global of the same name.
        end_token: Token that stops decoding. It is never included in the
            result. Defaults to the module-level global of the same name.
        max_summary_len: Decoding stops once ``max_summary_len - 1`` tokens
            have been emitted. Defaults to the module-level global of the
            same name.

    Returns:
        The decoded tokens, each preceded by one space, so a non-empty result
        starts with a space. Returns ``''`` when the first sampled token is
        ``end_token``.

    Raises:
        NameError: An argument was omitted and the matching module-level
            global is not defined.
        KeyError: ``start_token`` is missing from ``target_word_index``, or a
            sampled id is missing from ``reverse_target_word_index``.
    """
    target_word_index = _resolve_decode_setting(
        target_word_index, "target_word_index"
    )
    reverse_target_word_index = _resolve_decode_setting(
        reverse_target_word_index, "reverse_target_word_index"
    )
    start_token = _resolve_decode_setting(start_token, "start_token")
    end_token = _resolve_decode_setting(end_token, "end_token")
    max_summary_len = _resolve_decode_setting(max_summary_len, "max_summary_len")

    # Encode once; the encoder outputs are reused unchanged at every step.
    e_out, e_h, e_c = encoder_model.predict(input_sequence)

    # Length-1 decoder input, seeded with the start token's id.
    target_seq = np.zeros((1, 1))
    target_seq[0, 0] = target_word_index[start_token]

    # The "- 1" presumably leaves room for the end token -- confirm.
    max_words = max_summary_len - 1
    words = []
    while True:
        output_tokens, h, c = decoder_model.predict([target_seq, e_out, e_h, e_c])
        # Greedy choice: highest-scoring id at the last time step.
        sampled_token_index = int(np.argmax(output_tokens[0, -1, :]))
        sampled_token = reverse_target_word_index[sampled_token_index]
        if sampled_token == end_token:
            break
        words.append(sampled_token)
        # Count emitted tokens directly. Re-splitting the joined string each
        # step is quadratic and never terminates on empty/whitespace tokens.
        if len(words) >= max_words:
            break
        # Feed the sampled token and the decoder's new LSTM states back in.
        target_seq = np.zeros((1, 1))
        target_seq[0, 0] = sampled_token_index
        e_h, e_c = h, c

    # Preserve the original output format: every token prefixed by a space.
    return "".join(" " + word for word in words)
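# Stray text copied from the web page this snippet came from
# ("Ask question", "20 times"); not part of the program.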