2

所以我有一个包含这三个文件的文件夹:

  1. model.data-00000-of-00001
  2. 模型索引
  3. 模型.元

我想将 .data 文件转换为 .pb 文件。我已经检查了几乎一个链接,但我一直遇到错误并且卡住了。

这是我用来将检查点文件转换为 .pb 的脚本:

import tensorflow as tf
#Step 1 
#import the model metagraph
saver = tf.train.import_meta_graph('model.meta', clear_devices=True)
#make that as the default graph
graph = tf.get_default_graph()
input_graph_def = graph.as_graph_def()
sess = tf.Session()
#now restore the variables
saver.restore(sess, "model")
tf.compat.v1.disable_eager_execution()
#Step 2
# Find the output name
graph = tf.get_default_graph()
for op in graph.get_operations(): 
  print (op.name)

#Step 3
from tensorflow.python.platform import gfile
from tensorflow.python.framework import graph_util

output_node_names = 'b1,b2,b3,b5,b6,bidirectional_rnn/bw/basic_lstm_cell/bias,bidirectional_rnn/bw/basic_lstm_cell/kernel,bidirectional_rnn/fw/basic_lstm_cell/bias,bidirectional_rnn/fw/basic_lstm_cell/kernel,h1,h2,h3,h5,h6'
output_graph_def = graph_util.convert_variables_to_constants(
        sess, # The session
        input_graph_def, # input_graph_def is useful for retrieving the nodes 
        output_node_names.split(",")  )    

#Step 4
#output folder
output_fld ='./'
#output pb file name
output_model_file = 'model.pb'
from tensorflow.python.framework import graph_io
#write the graph
graph_io.write_graph(output_graph_def, output_fld, output_model_file, as_text=False)

这里的 output_nodes 来自 DeepSpeech (Mozilla) 版本 0.1.0。这个版本提供了 491.0MB 的默认模型,我生成的是 490.9MB。但它们都是从相同的检查点生成的。我想进一步训练我的检查点文件,但在此之前,我想看看我是否也可以先将其冻结当我尝试转录音频文件时,这就是我不断得到的:

Loading model from file model.pb
Loaded model in 0.423s.
Running inference.
Error running session: Not found: FeedInputs: unable to find feed output input_lengths
None
Inference took 0.011s for 1.975s audio file.

请帮忙!

4

0 回答 0