2

我已经用 tensorflow 训练了一个模型,并在训练期间使用了批量标准化。批量标准化需要用户传递一个名为 的布尔值is_training来设置模型是处于训练阶段还是测试阶段。

当模型被训练时,is_training被设置为一个常数,如下所示

is_training = tf.constant(True, dtype=tf.bool, name='is_training')

我保存了训练好的模型,文件包括检查点、.meta 文件、.index 文件和 .data。我想恢复模型并使用它运行推理。无法重新训练模型。所以,我想恢复现有模型,将值设置is_trainingFalse,然后将模型保存回来。如何编辑与该节点关联的布尔值并再次保存模型?

4

1 回答 1

4

您可以使用 的input_map参数tf.train.import_meta_graph将图张量重新映射到更新的值。

config = tf.ConfigProto(allow_soft_placement=True)
with tf.Session(config=config) as sess:
    # define the new is_training tensor
    is_training = tf.constant(False, dtype=tf.bool, name='is_training')

    # now import the graph using the .meta file of the checkpoint
    saver = tf.train.import_meta_graph(
    '/path/to/model.meta', input_map={'is_training:0':is_training})

    # restore all weights using the model checkpoint 
    saver.restore(sess, '/path/to/model')

    # save updated graph and variables values
    saver.save(sess, '/path/to/new-model-name')
于 2017-08-17T16:38:56.097 回答