我想在 Android 应用程序中使用我的 Tensorflow 算法。Tensorflow Android 示例首先下载包含模型定义和权重的 GraphDef(在 *.pb 文件中)。现在这应该来自我的 Scikit Flow 算法(Tensorflow 的一部分)。
乍一看,你只需要说 classifier.save('model/') 似乎很容易,但保存到该文件夹的文件不是 *.ckpt、*.def,当然也不是 *.pb。相反,您必须处理 *.pbtxt 和检查点(没有结尾)文件。
我被困在那里已经有一段时间了。这是一个导出内容的代码示例:
#imports
import tensorflow as tf
import tensorflow.contrib.learn as skflow
import tensorflow.contrib.learn.python.learn as learn
from sklearn import datasets, metrics
#skflow example
iris = datasets.load_iris()
feature_columns = learn.infer_real_valued_columns_from_input(iris.data)
classifier = learn.LinearClassifier(n_classes=3, feature_columns=feature_columns,model_dir="modeltest")
classifier.fit(iris.data, iris.target, steps=200, batch_size=32)
iris_predictions = list(classifier.predict(iris.data, as_iterable=True))
score = metrics.accuracy_score(iris.target, iris_predictions)
print("Accuracy: %f" % score)
你得到的文件是:
- 检查点
- 图.pbtxt
- 模型.ckpt-1.meta
- model.ckpt-1-00000-of-00001
- 模型.ckpt-200.meta
- model.ckpt-200-00000-of-00001
我发现的许多可能的解决方法都需要将 GraphDef 放在一个变量中(不知道如何使用 Scikit Flow)。或者似乎不需要使用 Scikit Flow 的 Tensorflow 会话。