3

我想在 Android 应用程序中使用我的 Tensorflow 算法。Tensorflow Android 示例首先下载包含模型定义和权重的 GraphDef(在 *.pb 文件中)。现在这应该来自我的 Scikit Flow 算法(Tensorflow 的一部分)。

乍一看,你只需要说 classifier.save('model/') 似乎很容易,但保存到该文件夹​​的文件不是 *.ckpt、*.def,当然也不是 *.pb。相反,您必须处理 *.pbtxt 和检查点(没有结尾)文件。

我被困在那里已经有一段时间了。这是一个导出内容的代码示例:

#imports
import tensorflow as tf
import tensorflow.contrib.learn as skflow
import tensorflow.contrib.learn.python.learn as learn
from sklearn import datasets, metrics

#skflow example
iris = datasets.load_iris()
feature_columns = learn.infer_real_valued_columns_from_input(iris.data)
classifier = learn.LinearClassifier(n_classes=3, feature_columns=feature_columns,model_dir="modeltest")
classifier.fit(iris.data, iris.target, steps=200, batch_size=32)
iris_predictions = list(classifier.predict(iris.data, as_iterable=True))
score = metrics.accuracy_score(iris.target, iris_predictions)
print("Accuracy: %f" % score)

你得到的文件是:

  • 检查点
  • 图.pbtxt
  • 模型.ckpt-1.meta
  • model.ckpt-1-00000-of-00001
  • 模型.ckpt-200.meta
  • model.ckpt-200-00000-of-00001

我发现的许多可能的解决方法都需要将 GraphDef 放在一个变量中(不知道如何使用 Scikit Flow)。或者似乎不需要使用 Scikit Flow 的 Tensorflow 会话。

4

1 回答 1

2

要保存为 pb 文件,您需要从构建的图形中提取 graph_def。你可以这样做——

from tensorflow.python.framework import tensor_shape, graph_util
from tensorflow.python.platform import gfile
sess = tf.Session()
final_tensor_name = 'results:0'     #Replace final_tensor_name with name of the final tensor in your graph
#########Build your graph and train########
## Your tensorflow code to build the graph
###########################################

outpt_filename = 'output_graph.pb'
output_graph_def = sess.graph.as_graph_def()
with gfile.FastGFile(outpt_filename, 'wb') as f:
  f.write(output_graph_def.SerializeToString())

如果要将训练的变量转换为常量(以避免使用 ckpt 文件加载权重),可以使用:

output_graph_def = graph_util.convert_variables_to_constants(sess, sess.graph.as_graph_def(), [final_tensor_name])

希望这可以帮助!

于 2016-10-02T01:36:09.507 回答