0

我已经使用 Seq2Seq 模型构建了一个基本的聊天机器人。当我在笔记本中按顺序运行代码时,机器人运行良好 - 即构建模型 - > 训练模型 - > 测试模型。

我现在想在训练后保存模型,加载模型然后测试模型。

但是,我遇到了问题/努力进一步进行。

这是我到目前为止所得到的:

保存模型

saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.save(sess, 'model_final.ckpt')
这似乎工作正常

加载模型

saver = tf.train.import_meta_graph("model_final.ckpt.meta")
graph = tf.get_default_graph()
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)
saver.restore(sess, "model_final.ckpt")
这似乎工作正常

当我按顺序运行时,下面的代码会完成输入问题、标记化并响应问题的工作。

prediction_c  = tf.argmax(model_c, 2)
result_c = sess_c.run(prediction_c,
                  feed_dict={enc_input_c: input_batch_c,
                             dec_input_c: output_batch_c,
                             targets_c: target_batch_c})

加载 Seq2Seq 模型后,我不确定 model_c、input_c 等变量如何获取值/初始化。

对于这个问题的基本性质,或者我试图实现的目标没有意义,我深表歉意;我刚刚开始研究张量。

4

1 回答 1

0

你调查过这个吗?

检查第 76-95 行的恢复代码:https ://github.com/keras-team/keras/blob/master/examples/lstm_seq2seq_restore.py

代码使用model.save和model.load分别保存和加载模型

正在恢复的模型是:https ://github.com/keras-team/keras/blob/master/examples/lstm_seq2seq.py

于 2019-05-05T07:27:58.867 回答