3

我用 char-rnn-tensorflow ( https://github.com/sherjilozair/char-rnn-tensorflow ) 训练了一个模型。模式保存到检查点。现在我想为模型提供 tensorflow 服务。

谷歌搜索了很多关于这个的教程,只发现符合我的需求。当我将代码作为教程更改为以下内容时。它返回“node_name 不在图中”错误。

用“[n.name for n in tf.get_default_graph().as_graph_def().node]”得到图中所有节点的名称,超过 10000 对我来说很疯狂,我想知道哪个属于我。

所以这里的问题是,有没有更好的方法来查找我在训练时使用的节点名称。或任何更好的解决方案将检查点转换为 tensorlfow 服务中使用的 savemodel?

谢谢!

import tensorflow as tf
from model import Model
import argparse
import os
from six.moves import cPickle
from model import Model
#  Build the signature_def_map.
# X: ry, pkeep: 1.0, Hin: rh, batchsize: 1
def main():
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument('--data_dir', type=str, default='data/obama',
                        help='data directory containing input.txt')
    parser.add_argument('--output_node_names', type=str, default='node_name',
                        help='output node names')
    parser.add_argument('--output_graph', type=str, default='output_graph',
                        help='output_graph')
    parser.add_argument('--save_dir', type=str, default='save_train3',
                        help='directory to store checkpointed models')

    args = parser.parse_args()
    print(args)
    freeze_graph(args)

def freeze_graph(args):
    with open(os.path.join(args.save_dir, 'config.pkl'), 'rb') as f:
        saved_args = cPickle.load(f)
    with open(os.path.join(args.save_dir, 'chars_vocab.pkl'), 'rb') as f:
        chars, vocab = cPickle.load(f)
    print(saved_args)
    model = Model(saved_args, training=False)
    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        print(tf.global_variables())
        saver = tf.train.Saver(tf.global_variables())
        ckpt = tf.train.get_checkpoint_state(args.save_dir)
        print(ckpt.model_checkpoint_path)
    # We import the meta graph in the current default Graph
        saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path + '.meta', clear_devices=True)
        # We restore the weights
        saver.restore(sess, ckpt.model_checkpoint_path)
        # We use a built-in TF helper to export variables to constants
        print(len([n.name for n in tf.get_default_graph().as_graph_def().node]))
        output_graph_def = tf.graph_util.convert_variables_to_constants(
            sess, # The session is used to retrieve the weights
            tf.get_default_graph().as_graph_def(), # The graph_def is used to retrieve the nodes 
            args.output_node_names.split(",") # The output node names are used to select the usefull nodes
        ) 
        # Finally we serialize and dump the output graph to the filesystem
        with tf.gfile.GFile(args.output_graph + "model.pb", "wb") as f:
            f.write(output_graph_def.SerializeToString())
        print("%d ops in the final graph." % len(output_graph_def.node))

if __name__ == "__main__":
    main()
4

0 回答 0