我正在尝试开发具有注意力机制的聊天机器人。但它给出了这样的错误。我的 x_train 输入形状是 (None, 27),输出形状是 (None, 27, 8870)。但我无法正确识别错误。
def chatbot_model(embedding_size, max_sentence_length, vocab_size, embedding_matrix, batch_size=None):
    """Build a seq2seq chatbot with attention: one training model plus
    separate inference-time encoder and decoder models.

    Args:
        embedding_size: dimensionality of the word embeddings. The decoder
            LSTM uses ``2 * embedding_size`` units so its state size matches
            the concatenated forward/backward encoder states.
        max_sentence_length: padded token length of encoder/decoder inputs.
        vocab_size: vocabulary size (width of the softmax output).
        embedding_matrix: pre-trained embedding weights,
            shape (vocab_size, embedding_size).
        batch_size: optional fixed batch size for the training model.

    Returns:
        (full_model, encoder_model, decoder_model)

    NOTE(review): ``full_model`` has TWO inputs. Training must feed both,
    e.g. ``full_model.fit([enc_x, dec_x], y, ...)``. Calling
    ``fit(x_train, y)`` with a single array leaves the decoder branch
    unfed, which is exactly what produces
    ``AssertionError: Could not compute output Tensor(...)``.
    """
    if batch_size:
        encoder_inputs = Input(batch_shape=(batch_size, max_sentence_length,), name='encoder_inputs')
        decoder_inputs = Input(batch_shape=(batch_size, max_sentence_length,), name='decoder_inputs')
    else:
        encoder_inputs = Input(shape=(max_sentence_length,), name='encoder_inputs')
        decoder_inputs = Input(shape=(max_sentence_length,), name='decoder_inputs')

    # One embedding shared by encoder and decoder tokens.
    embedding_layer = Embedding(vocab_size, embedding_size,
                                weights=[embedding_matrix],
                                input_length=max_sentence_length)
    encoder_inputs_embed = embedding_layer(encoder_inputs)
    decoder_inputs_embed = embedding_layer(decoder_inputs)

    encoder_lstm = Bidirectional(
        LSTM(embedding_size, return_sequences=True, return_state=True,
             name='encoder_lstm'),
        name='bidirectional_encoder')
    (encoder_out,
     encoder_fwd_state_h, encoder_fwd_state_c,
     encoder_back_state_h, encoder_back_state_c) = encoder_lstm(encoder_inputs_embed)

    # Concatenate forward/backward states so they match the decoder's
    # 2*embedding_size state size.
    state_h = Concatenate()([encoder_fwd_state_h, encoder_back_state_h])
    state_c = Concatenate()([encoder_fwd_state_c, encoder_back_state_c])
    enc_states = [state_h, state_c]

    decoder_lstm = LSTM(embedding_size * 2, return_sequences=True,
                        return_state=True, name='decoder_lstm')
    # Training decoder: only the output sequence is needed here.
    decoder_out, _, _ = decoder_lstm(decoder_inputs_embed, initial_state=enc_states)

    attn_layer = AttentionLayer(name='attention_layer')
    attn_out, attn_states = attn_layer([encoder_out, decoder_out])
    decoder_concat_input = Concatenate(axis=-1, name='concat_layer')([decoder_out, attn_out])

    dense = Dense(vocab_size, activation='softmax', name='softmax_layer')
    dense_time = TimeDistributed(dense, name='time_distributed_layer')
    decoder_pred = dense_time(decoder_concat_input)

    full_model = Model(inputs=[encoder_inputs, decoder_inputs], outputs=decoder_pred)
    full_model.compile(optimizer='adam', loss='categorical_crossentropy')
    full_model.summary()

    # ---------------- Inference models (batch of 1, one step at a time) ----
    batch_size = 1
    encoder_inf_inputs = Input(batch_shape=(batch_size, max_sentence_length,),
                               name='encoder_inf_inputs')
    encoder_inf_inputs_embed = embedding_layer(encoder_inf_inputs)
    (encoder_inf_out,
     encoder_inf_fwd_state_h, encoder_inf_fwd_state_c,
     encoder_inf_back_state_h, encoder_inf_back_state_c) = encoder_lstm(encoder_inf_inputs_embed)

    # BUG FIX: the original built `enc_inf_states = [inf_state_h, state_c]`,
    # accidentally reusing the *training* graph's `state_c`. The list was
    # never used, so it is removed; callers should concatenate the returned
    # fwd/back states to initialise the decoder.
    encoder_model = Model(
        inputs=encoder_inf_inputs,
        outputs=[encoder_inf_out,
                 encoder_inf_fwd_state_h, encoder_inf_fwd_state_c,
                 encoder_inf_back_state_h, encoder_inf_back_state_c])

    decoder_inf_inputs = Input(batch_shape=(batch_size, 1,), name='decoder_word_inputs')
    decoder_inf_inputs_embed = embedding_layer(decoder_inf_inputs)
    encoder_inf_states = Input(batch_shape=(batch_size, max_sentence_length, 2 * embedding_size),
                               name='encoder_inf_states')
    decoder_init_state_h = Input(batch_shape=(batch_size, 2 * embedding_size),
                                 name='decoder_init_state_h')
    decoder_init_state_c = Input(batch_shape=(batch_size, 2 * embedding_size),
                                 name='decoder_init_state_c')

    decoder_inf_out, decoder_inf_state_h, decoder_inf_state_c = decoder_lstm(
        decoder_inf_inputs_embed,
        initial_state=[decoder_init_state_h, decoder_init_state_c])
    attn_inf_out, attn_inf_states = attn_layer([encoder_inf_states, decoder_inf_out])
    decoder_inf_concat = Concatenate(axis=-1, name='concat')([decoder_inf_out, attn_inf_out])
    decoder_inf_pred = TimeDistributed(dense)(decoder_inf_concat)

    # FIX: pass flat lists of Inputs/outputs — nested lists (e.g.
    # `[..., decoder_init_states, ...]`) are not reliably accepted by
    # tf.keras Model and break in TF2.
    decoder_model = Model(
        inputs=[encoder_inf_states,
                decoder_init_state_h, decoder_init_state_c,
                decoder_inf_inputs],
        outputs=[decoder_inf_pred, attn_inf_states,
                 decoder_inf_state_h, decoder_inf_state_c])

    return full_model, encoder_model, decoder_model
它给出了这样的错误:
AssertionError Traceback (most recent call last)
in () ----> 1 full_model.fit(x_train[:1000, :], outs, epochs=1, batch_size=BATCH_SIZE)
AssertionError: in user code:
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py:571 train_function *
outputs = self.distribute_strategy.run(
/usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/distribute_lib.py:951 run **
return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/distribute_lib.py:2290 call_for_each_replica
return self._call_for_each_replica(fn, args, kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/distribute_lib.py:2649 _call_for_each_replica
return fn(*args, **kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py:531 train_step **
y_pred = self(x, training=True)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/base_layer.py:927 __call__
outputs = call_fn(cast_inputs, *args, **kwargs)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/network.py:719 call
convert_kwargs_to_constants=base_layer_utils.call_context().saving)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/network.py:899 _run_internal_graph
assert str(id(x)) in tensor_dict, 'Could not compute output ' + str(x)
AssertionError: Could not compute output Tensor("time_distributed_layer/Identity:0", shape=(None, 27, 8870), dtype=float32)