我已经使用 Tensorflow Estimator API 训练了一个自定义 CNN 模型。我已成功冻结图表,但转换为 UFF 失败并引发以下错误:
'KeyError: u'IteratorGetNext:1'
进行上述转换的代码:
frozen_graph_filename = "Frozen_model.pb"
TMP_UFF_FILENAME = "output.uff"
output_name = "sigmoid"
uff_model = uff.from_tensorflow_frozen_model(
frozen_file=frozen_graph_filename,
output_nodes=[output_name],
output_filename=TMP_UFF_FILENAME,
text=False,
)
图中节点的名称是,
prefix/OneShotIterator
prefix/IteratorGetNext
prefix/Reshape/shape
prefix/Reshape
prefix/Reshape_1/shape
prefix/Reshape_1
prefix/conv1/kernel
prefix/conv1/bias
.
.
.
prefix/logits/MatMul
prefix/logits/BiasAdd
prefix/sigmoid
那么有没有办法移除前两个 Iterator 节点呢?它们在训练环境之外毫无用处。我也使用过tf.graph_util.remove_training_nodes
,但它并没有缓解我面临的问题。