3

我有一个经过训练的 tensorflow 模型以检查点、.data、.meta 和 .index 文件的形式保存。该模型使用批量标准化。我尝试使用freeze_graph将其转换为 .pb 文件,该文件可以导入为.pb 文件from tensorflow.python.tools import freeze_graph。它的输入也是一个.pb文件,但它只有图形结构。我使用以下代码恢复模型

sess = tf.Session()
saver = tf.train.import_meta_graph(r'.\path\to\model\VanillaCNN.0000.meta') 
saver.restore(sess, tf.train.latest_checkpoint(r'.\path\to\model'))
graph = tf.get_default_graph()

然后创建一个.pb包含图形结构的文件

tf.train.write_graph(sess.graph_def, "", "model_proto.pb", False)

在此之后,我freeze_graph用来生成一个.pb包含图形结构和权重的文件。的输入freeze_graph

input_graph_path = r'.\path\to\model\model_proto.pb'
input_saver_def_path = ""
input_binary = False
input_checkpoint_path = r'.\path\to\model\VanillaCNN.0000'
output_node_names = "VanillaCNNoutput_10/layer_output"
restore_op_name = "save/restore_all"
filename_tensor_name = "save/Const:0"
output_graph_path = r'.\path\to\model\frozen_model.pb'
clear_devices = False
initializer_nodes=""

执行为

freeze_graph.freeze_graph(input_graph_path, input_saver_def_path,input_binary, input_checkpoint_path,output_node_names, restore_op_name,filename_tensor_name, output_graph_path,clear_devices,initializer_nodes)

这会创建frozen_model.pb,当我尝试如下加载它时

def load_graph(frozen_graph_filename):
    with tf.gfile.GFile(frozen_graph_filename, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, input_map=None, return_elements=None, name="", op_dict=None, producer_op_list=None)
    return graph

它引发以下错误

ValueError: graph_def is invalid at node 'VanillaCNNconv_0/VanillaCNNconv_0/cond/Assign': Input tensor 'VanillaCNNconv_0/VanillaCNNconv_0/cond/Assign/Switch:1' Cannot convert a tensor of type float32 to an input of type float32_ref.

我怎样才能解决这个问题?

4

0 回答 0