我用 tensorflow 定义了一个深度 CNN,包括一个批量标准化操作,即,我的代码可能如下所示:
def network(input):
...
input = tf.layers.batch_normalization(input, ...)
...
假设网络已经训练好,并且检查点文件已经保存。现在我想用这个模型进行推理。通常,我可以network(input)
再次调用该函数,除了将参数传递training=False
给tf.layers.batch_normalization()
,然后从检查点文件中恢复权重。
但是,我更喜欢用它tf.import_meta_graph
来重建我的网络,因为函数中的代码network(input)
可以更改。
但是现在我如何在推理模式下设置批量标准化操作呢?由于我无法访问 function tf.layers.batch_normalization()
,因此我很难解决这个问题。