0

我正在使用这个演示中的 retrain.python 文件。我得到不同类型的文件:

在此处输入图像描述

在此处输入图像描述

在此处输入图像描述

我想用检查点文件冻结 graph.pb,优化冻结的文件,然后将优化的文件转换为 tflite 文件,以便在 android 应用程序中使用它。

我尝试了不同的方法来冻结文件,但没有运气,

终端中不存在获取检查点文件

UnicodeDecodeError:“utf-8”编解码器无法解码位置 1 的字节 0x86:无效的起始字节

如何完成所有步骤并获取tflite文件以及如何合并labels.txt文件?

注意:这是我在终端中使用的命令:

python freeze_graph.py \ 
--input_graph=/home/automator/Desktop/retrain/code/graph/graph.pb \ 
--input_checkpoint=/home/automator/Desktop/retrain/code/tmp/model.ckpt \ 
--output_graph=/home/automator/Desktop/retrain/code/frozen.pb \ 
--output_node_names=output_node \
--input_saved_model_dir=/home/automator/Desktop/retrain/code/export/frozen.pb \ --output_node_names=outInput 

错误:检查点''不存在!

试过:

--input_checkpoint=/home/automator/Desktop/retrain/code/tmp/model.ckpt
--input_checkpoint=/home/automator/Desktop/retrain/code/tmp/model
--input_checkpoint=/home/automator/Desktop/retrain/code/tmp/modelmodel.ckpt
....

请帮忙!

4

2 回答 2

2

这是一个冻结图形的好脚本

import os
import argparse
import tensorflow as tf
from tensorflow.python.framework import graph_util
from tensorflow.python.platform import gfile


def load_graph_def(model_path, sess=None):
    if os.path.isfile(model_path):
        with gfile.FastGFile(model_path, 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
            tf.import_graph_def(graph_def, name='')
    else:
        sess = sess if sess is not None else tf.get_default_session()
        saver = tf.train.import_meta_graph(model_path + '.meta')
        saver.restore(sess, model_path)


def freeze_from_checkpoint(checkpoint_file, output_layer_name):

    model_folder = os.path.basename(checkpoint_file)
    output_graph = os.path.join(model_folder, checkpoint_file + '.pb')

    with tf.Session() as sess:

        load_graph_def(checkpoint_file)

        graph = tf.get_default_graph()
        input_graph_def = graph.as_graph_def()

        print("Exporting graph...")
        output_graph_def = graph_util.convert_variables_to_constants(
            sess,
            input_graph_def,
            output_layer_name.split(","))

        with tf.gfile.GFile(output_graph, "wb") as f:
            f.write(output_graph_def.SerializeToString())


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('model_path')
    parser.add_argument('output_layer')
    args = parser.parse_args()
    freeze_from_checkpoint(checkpoint_file=args.model_path, output_layer_name=args.output_layer)

将其保存为 freeze_graph.py

称之为:python freeze_graph.py /home/automator/Desktop/retrain/code/tmp/model.data-000000-of-00001 "output_node_name"

于 2018-06-17T22:07:30.613 回答
0

鉴于您已meta graph保存,请尝试使用以下input_meta_graph参数:

python freeze_graph.py \ 
--input_meta_graph=/home/automator/Desktop/retrain/code/tmp/model.meta \ 
--input_checkpoint=/home/automator/Desktop/retrain/code/tmp/model.ckpt \ 
--input_binary=true \
--output_graph=/home/automator/Desktop/retrain/code/frozen.pb \ 
--output_node_names=output_node 

问题是您传递的--input_saved_model_dir参数覆盖了该input_meta_graph参数,但您似乎没有SavedModel

于 2018-06-18T16:44:30.087 回答