我正在使用自定义图像集来训练使用 Tensorflow API 的神经网络。在成功的训练过程之后,我得到了这些包含不同训练变量值的检查点文件。我现在想从这些检查点文件中获取推理模型,我找到了执行此操作的脚本,然后我可以使用它来生成 deepdream 图像,如本教程中所述。问题是当我使用以下方法加载模型时:
import tensorflow as tf
model_fn = 'export'
graph = tf.Graph()
sess = tf.InteractiveSession(graph=graph)
with tf.gfile.FastGFile(model_fn, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
t_input = tf.placeholder(np.float32, name='input')
imagenet_mean = 117.0
t_preprocessed = tf.expand_dims(t_input-imagenet_mean, 0)
tf.import_graph_def(graph_def, {'input':t_preprocessed})
我收到此错误:
graph_def.ParseFromString(f.read())
self.MergeFromString(序列化)
raise message_mod.DecodeError('Unexpected end-group tag.') google.protobuf.message.DecodeError: Unexpected end-group tag。
该脚本需要一个协议缓冲区文件,我不确定我用来生成推理模型的脚本是否给了我原型缓冲区文件。
有人可以建议我做错了什么,或者有更好的方法来实现这一点。我只是想将张量生成的检查点文件转换为原型缓冲区。
谢谢