你好StackOverflow社区!
我正在尝试为带有Attention的seq2seq ( Encoder-Decoder ) 模型创建推理模型。这是推理模型的定义。
model = compile_model(tf.keras.models.load_model(constant.MODEL_PATH, compile=False))
encoder_input = model.input[0]
encoder_output, encoder_h, encoder_c = model.layers[1].output
encoder_state = [encoder_h, encoder_c]
encoder_model = tf.keras.Model(encoder_input, encoder_state)
decoder_input = model.input[1]
decoder = model.layers[3]
decoder_new_h = tf.keras.Input(shape=(n_units,), name='input_3')
decoder_new_c = tf.keras.Input(shape=(n_units,), name='input_4')
decoder_input_initial_state = [decoder_new_h, decoder_new_c]
decoder_output, decoder_h, decoder_c = decoder(decoder_input, initial_state=decoder_input_initial_state)
decoder_output_state = [decoder_h, decoder_c]
# These lines cause an error
context = model.layers[4]([encoder_output, decoder_output])
decoder_combined_context = model.layers[5]([context, decoder_output])
output = model.layers[6](decoder_combined_context)
output = model.layers[7](output)
# end
decoder_model = tf.keras.Model([decoder_input] + decoder_input_initial_state, [output] + decoder_output_state)
return encoder_model, decoder_model
当我运行此代码时,出现以下错误。
ValueError: Graph disconnected: cannot obtain value for tensor Tensor("input_5:0", shape=(None, None, 20), dtype=float32) at layer "lstm_4". The following previous layers were accessed without issue: ['lstm_5']
如果我排除一个注意块,模型将完全没有任何错误。
# Variant that works: the attention block is commented out, so every tensor
# used by decoder_model traces back only to decoder_input and the two new
# state Inputs — the functional graph stays connected.
model = compile_model(tf.keras.models.load_model(constant.MODEL_PATH, compile=False))
# Encoder branch: layer[1] is the encoder LSTM; its output tuple is
# (sequence output, final hidden state h, final cell state c).
encoder_input = model.input[0]
encoder_output, encoder_h, encoder_c = model.layers[1].output
encoder_state = [encoder_h, encoder_c]
encoder_model = tf.keras.Model(encoder_input, encoder_state)
# Decoder branch: re-call the trained decoder LSTM (layer[3]) with fresh
# Input tensors for the initial state, fed step-by-step at inference time.
decoder_input = model.input[1]
decoder = model.layers[3]
decoder_new_h = tf.keras.Input(shape=(n_units,), name='input_3')
decoder_new_c = tf.keras.Input(shape=(n_units,), name='input_4')
decoder_input_initial_state = [decoder_new_h, decoder_new_c]
decoder_output, decoder_h, decoder_c = decoder(decoder_input, initial_state=decoder_input_initial_state)
decoder_output_state = [decoder_h, decoder_c]
# These lines cause an error when enabled: `encoder_output` belongs to the
# training graph (reachable only via encoder_input, which is not an input
# of decoder_model), producing the "Graph disconnected" ValueError.
# context = model.layers[4]([encoder_output, decoder_output])
# decoder_combined_context = model.layers[5]([context, decoder_output])
# output = model.layers[6](decoder_combined_context)
# output = model.layers[7](output)
# end
decoder_model = tf.keras.Model([decoder_input] + decoder_input_initial_state, [decoder_output] + decoder_output_state)
return encoder_model, decoder_model