我有一个经过训练的 tensorflow 模型以检查点、.data、.meta 和 .index 文件的形式保存。该模型使用批量标准化。我尝试使用freeze_graph将其转换为 .pb 文件,该文件可以导入为.pb 文件from tensorflow.python.tools import freeze_graph
。它的输入也是一个.pb
文件,但它只有图形结构。我使用以下代码恢复模型
sess = tf.Session()
saver = tf.train.import_meta_graph(r'.\path\to\model\VanillaCNN.0000.meta')
saver.restore(sess, tf.train.latest_checkpoint(r'.\path\to\model'))
graph = tf.get_default_graph()
然后创建一个.pb
包含图形结构的文件
tf.train.write_graph(sess.graph_def, "", "model_proto.pb", False)
在此之后,我freeze_graph
用来生成一个.pb
包含图形结构和权重的文件。的输入freeze_graph
是
input_graph_path = r'.\path\to\model\model_proto.pb'
input_saver_def_path = ""
input_binary = False
input_checkpoint_path = r'.\path\to\model\VanillaCNN.0000'
output_node_names = "VanillaCNNoutput_10/layer_output"
restore_op_name = "save/restore_all"
filename_tensor_name = "save/Const:0"
output_graph_path = r'.\path\to\model\frozen_model.pb'
clear_devices = False
initializer_nodes=""
执行为
freeze_graph.freeze_graph(input_graph_path, input_saver_def_path,input_binary, input_checkpoint_path,output_node_names, restore_op_name,filename_tensor_name, output_graph_path,clear_devices,initializer_nodes)
这会创建frozen_model.pb
,当我尝试如下加载它时
def load_graph(frozen_graph_filename):
with tf.gfile.GFile(frozen_graph_filename, "rb") as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
with tf.Graph().as_default() as graph:
tf.import_graph_def(graph_def, input_map=None, return_elements=None, name="", op_dict=None, producer_op_list=None)
return graph
它引发以下错误
ValueError: graph_def is invalid at node 'VanillaCNNconv_0/VanillaCNNconv_0/cond/Assign': Input tensor 'VanillaCNNconv_0/VanillaCNNconv_0/cond/Assign/Switch:1' Cannot convert a tensor of type float32 to an input of type float32_ref.
我怎样才能解决这个问题?