我用 char-rnn-tensorflow ( https://github.com/sherjilozair/char-rnn-tensorflow ) 训练了一个模型。模式保存到检查点。现在我想为模型提供 tensorflow 服务。
谷歌搜索了很多关于这个的教程,只发现这符合我的需求。当我将代码作为教程更改为以下内容时。它返回“node_name 不在图中”错误。
用“[n.name for n in tf.get_default_graph().as_graph_def().node]”得到图中所有节点的名称,超过 10000 对我来说很疯狂,我想知道哪个属于我。
所以这里的问题是,有没有更好的方法来查找我在训练时使用的节点名称。或任何更好的解决方案将检查点转换为 tensorlfow 服务中使用的 savemodel?
谢谢!
import tensorflow as tf
from model import Model
import argparse
import os
from six.moves import cPickle
from model import Model
# Build the signature_def_map.
# X: ry, pkeep: 1.0, Hin: rh, batchsize: 1
def main():
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--data_dir', type=str, default='data/obama',
help='data directory containing input.txt')
parser.add_argument('--output_node_names', type=str, default='node_name',
help='output node names')
parser.add_argument('--output_graph', type=str, default='output_graph',
help='output_graph')
parser.add_argument('--save_dir', type=str, default='save_train3',
help='directory to store checkpointed models')
args = parser.parse_args()
print(args)
freeze_graph(args)
def freeze_graph(args):
with open(os.path.join(args.save_dir, 'config.pkl'), 'rb') as f:
saved_args = cPickle.load(f)
with open(os.path.join(args.save_dir, 'chars_vocab.pkl'), 'rb') as f:
chars, vocab = cPickle.load(f)
print(saved_args)
model = Model(saved_args, training=False)
with tf.Session() as sess:
tf.global_variables_initializer().run()
print(tf.global_variables())
saver = tf.train.Saver(tf.global_variables())
ckpt = tf.train.get_checkpoint_state(args.save_dir)
print(ckpt.model_checkpoint_path)
# We import the meta graph in the current default Graph
saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path + '.meta', clear_devices=True)
# We restore the weights
saver.restore(sess, ckpt.model_checkpoint_path)
# We use a built-in TF helper to export variables to constants
print(len([n.name for n in tf.get_default_graph().as_graph_def().node]))
output_graph_def = tf.graph_util.convert_variables_to_constants(
sess, # The session is used to retrieve the weights
tf.get_default_graph().as_graph_def(), # The graph_def is used to retrieve the nodes
args.output_node_names.split(",") # The output node names are used to select the usefull nodes
)
# Finally we serialize and dump the output graph to the filesystem
with tf.gfile.GFile(args.output_graph + "model.pb", "wb") as f:
f.write(output_graph_def.SerializeToString())
print("%d ops in the final graph." % len(output_graph_def.node))
if __name__ == "__main__":
main()