我已经使用 Seq2Seq 模型构建了一个基本的聊天机器人。当我在笔记本中按顺序运行代码时,机器人运行良好 - 即构建模型 - > 训练模型 - > 测试模型。
我现在想在训练后保存模型,加载模型然后测试模型。
但是,我遇到了问题/努力进一步进行。
这是我到目前为止所得到的:
保存模型
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
saver.save(sess, 'model_final.ckpt')
这似乎工作正常
加载模型
saver = tf.train.import_meta_graph("model_final.ckpt.meta")
graph = tf.get_default_graph()
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)
saver.restore(sess, "model_final.ckpt")
这似乎工作正常
当我按顺序运行时,下面的代码会完成输入问题、标记化并响应问题的工作。
prediction_c = tf.argmax(model_c, 2)
result_c = sess_c.run(prediction_c,
feed_dict={enc_input_c: input_batch_c,
dec_input_c: output_batch_c,
targets_c: target_batch_c})
加载 Seq2Seq 模型后,我不确定 model_c、input_c 等变量如何获取值/初始化。
对于这个问题的基本性质,或者我试图实现的目标没有意义,我深表歉意;我刚刚开始研究张量。