
I am trying to freeze a checkpoint containing batch norm layers (tf 1.1.0) into a .pb file. To do so, I followed these posts and questions:

I use this function:

import tensorflow as tf
from tensorflow.core.framework import attr_value_pb2


def freeze_and_prune_graph(model_path_and_name, output_file=None):
    """
    freezes a model trained and saved by the trainer by:
        - extracting the trainable variables between input_node and output_node
        - turning them into constants
        - changing the first dim of input_node to None
        - saving the resulting graph as a single .pb file

    :param model_path_and_name: path to the trained model; must end with .ckpt, and
        the checkpoint must be composed of 3+ files: .ckpt.index, .ckpt.meta,
        and .ckpt.data-0000X-of-0000Y
    :param output_file: file to save to. If None, model_path_and_name.[-ckpt][+pb]
    :return: None
    """
    config_proto = tf.ConfigProto(allow_soft_placement=True)

    with tf.Session(config=config_proto) as sess:
        new_saver = tf.train.import_meta_graph(model_path_and_name + '.meta', clear_devices=True)
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        new_saver.restore(sess, model_path_and_name)

        # get graph definition
        gd = sess.graph.as_graph_def()
        # fix batch norm nodes
        for node in gd.node:
            if node.op == 'RefSwitch':
                node.op = 'Switch'
                for index in range(len(node.input)):
                    if 'moving_' in node.input[index]:
                        node.input[index] = node.input[index] + '/read'
            elif node.op == 'AssignSub':
                node.op = 'Sub'
                if 'use_locking' in node.attr: del node.attr['use_locking']
            elif node.op == 'AssignAdd':
                node.op = 'Add'
                if 'use_locking' in node.attr: del node.attr['use_locking']

        # get the input placeholder so that its batch dimension can be freed below
        input_node = sess.graph.get_tensor_by_name('input_node:0')
        new_shape = [None] + input_node.get_shape().as_list()[1:]

        trainables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        new_graph_def = tf.graph_util.convert_variables_to_constants(sess, gd, ["output_node"],
                                                                     variable_names_whitelist=[t.name[:-2] for t in trainables] + ['output_node'])

        for node in new_graph_def.node:
            if node.name == 'input_node':
                node.attr['shape'].CopyFrom(attr_value_pb2.AttrValue(shape=tf.TensorShape(new_shape).as_proto()))
                break

        if output_file is None:
            output_file = model_path_and_name[:-len('.ckpt')] + '.pb'
        with tf.gfile.GFile(output_file, "wb") as f:
            f.write(new_graph_def.SerializeToString())
        print("{0} / {1} ops in the final graph.".format(len(new_graph_def.node), len(sess.graph.as_graph_def().node)))

This runs smoothly and creates the frozen .pb file, with the following output:

Converted 201 variables to const ops.
5287 / 41028 ops in the final graph.
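Before reloading, it can help to check whether any `Variable` ops survived the conversion: the batch norm moving averages are not trainable, so a whitelist built from `TRAINABLE_VARIABLES` alone may skip them. A minimal diagnostic sketch (the helper name and usage paths are assumptions, not part of the original code):

```python
def find_variable_ops(graph_def):
    """Return the names of nodes that are still TF Variable ops.

    Works on any tf.GraphDef-like object whose .node entries carry
    .name and .op fields.
    """
    return [n.name for n in graph_def.node
            if n.op in ('Variable', 'VariableV2')]

# Usage with TensorFlow available:
#   gd = tf.GraphDef()
#   with tf.gfile.GFile(frozen_graph_file, 'rb') as f:
#       gd.ParseFromString(f.read())
#   leftovers = find_variable_ops(gd)
#   # any .../moving_mean or .../moving_variance names listed here were
#   # NOT converted to constants and will trigger FailedPreconditionError
```

If the returned list is non-empty, the freeze step did not actually constant-fold those nodes.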

I then try to load the model from the .pb file with this function:

def load_frozen_graph(frozen_graph_file):
    """
    loads a graph frozen via freeze_and_prune_graph and returns the graph, its
    input placeholders, and its output tensor

    :param frozen_graph_file: .pb file to load
    :return: tf.Graph, image placeholder, phase placeholder (or None), output tensor
    """
    # We load the protobuf file from the disk and parse it to retrieve the
    # unserialized graph_def
    with tf.gfile.GFile(frozen_graph_file, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    # Then, we can use again a convenient built-in function to import a graph_def into the
    # current default Graph
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(
            graph_def,
            input_map=None,
            return_elements=None,
            name="prefix",
            op_dict=None,
            producer_op_list=None
        )

    input_images_placeholder = graph.get_tensor_by_name('prefix/input_node:0')
    input_phase_placeholder = None
    try:
        input_phase_placeholder = graph.get_tensor_by_name('prefix/phase:0')
    except KeyError:
        pass
    output = graph.get_tensor_by_name('prefix/output_node:0')

    return graph, input_images_placeholder, input_phase_placeholder, output

using the following snippet:

graph, input_images_placeholder, is_training_placeholder, output = load_frozen_graph(model_pbtxt)
sess = tf.Session(config=tf_config, graph=graph)
feed_dict = {input_images_placeholder: prepared_input}
if is_training_placeholder is not None:
    feed_dict[is_training_placeholder] = False
ret = sess.run([output], feed_dict=feed_dict)

However, this results in the following error:

FailedPreconditionError (see above for traceback):
    Attempting to use uninitialized value prefix/conv0/BatchNorm/batch_normalization/moving_mean
[
    [
        Node: prefix/conv0/BatchNorm/batch_normalization/moving_mean/read = Identity[
            T=DT_FLOAT,
            _class=["loc:@prefix/conv0/BatchNorm/batch_normalization/moving_mean"],
            _device="/job:localhost/replica:0/task:0/gpu:0"
        ](prefix/conv0/BatchNorm/batch_normalization/moving_mean)
    ]
]
[
    [
        Node: prefix/output_node/_381 = _Recv[
            client_terminated=false,
            recv_device="/job:localhost/replica:0/task:0/cpu:0",
            send_device="/job:localhost/replica:0/task:0/gpu:0",
            send_device_incarnation=1,
            tensor_name="edge_2447_prefix/output_node",
            tensor_type=DT_FLOAT,
            _device="/job:localhost/replica:0/task:0/cpu:0"
        ]()
    ]
]

Following the question TensorFlow: "Attempting to use uninitialized value" in variable initialization, I tried initializing the variables:

graph, input_images_placeholder, is_training_placeholder, output = load_frozen_graph(model_pbtxt)
sess = tf.Session(config=tf_config, graph=graph)
init = [tf.global_variables_initializer(), tf.local_variables_initializer()]
sess.run(init)

feed_dict = {input_images_placeholder: prepared_input}
if is_training_placeholder is not None:
    feed_dict[is_training_placeholder] = False
ret = sess.run([output], feed_dict=feed_dict)

However, this changes the error to:

ValueError: Fetch argument <tf.Operation 'init' type=NoOp> cannot be interpreted as a Tensor. 
(Operation name: "init" op: "NoOp" is not an element of this graph.)

This seems to indicate that there are no variables to initialize (and indeed, the initializer ops were created in the default graph rather than in the loaded graph, hence the "is not an element of this graph" message).
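One plausible explanation (an assumption, not something the traceback alone proves): `moving_mean` and `moving_variance` live in `GLOBAL_VARIABLES` but not in `TRAINABLE_VARIABLES`, so the `variable_names_whitelist` built from trainables excludes them, and `convert_variables_to_constants` leaves them as uninitialized `Variable` ops. A sketch of building the whitelist from every saved variable instead (the helper name `build_whitelist` is hypothetical):

```python
def build_whitelist(variable_names):
    """Turn variable tensor names (usually ending in ':0') into the bare
    node names expected by variable_names_whitelist."""
    return [name[:-2] if name.endswith(':0') else name
            for name in variable_names]

# With TensorFlow, whitelist all global variables so the batch norm
# moving averages are converted too (or simply pass no whitelist at all,
# which converts every variable reachable from the output node):
#   all_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
#   new_graph_def = tf.graph_util.convert_variables_to_constants(
#       sess, gd, ["output_node"],
#       variable_names_whitelist=build_whitelist([v.name for v in all_vars]))
```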

我错过了什么?如何冻结并重新加载 batch_normalization 层的相关值?
