我正在遵循 XLA AOT 编译的官方指南(https://www.tensorflow.org/xla/tfcompile),并且编译示例工作得很好(在 aot/tests 内)。
但是后来我想编译一些稍微大一点的模型,然后出现了一个问题:如果 XLA AOT 需要一个冻结图作为输入(正如我从指南中理解的那样)并且 TensorFlow 2 不再支持冻结图,那么 XLA 现在期望什么输入?
我正在遵循 XLA AOT 编译的官方指南(https://www.tensorflow.org/xla/tfcompile),并且编译示例工作得很好(在 aot/tests 内)。
但是后来我想编译一些稍微大一点的模型,然后出现了一个问题:如果 XLA AOT 需要一个冻结图作为输入(正如我从指南中理解的那样)并且 TensorFlow 2 不再支持冻结图,那么 XLA 现在期望什么输入?
似乎仍有方法可以在 TensorFlow 2 中冻结图。我按照这篇文章创建了一个冻结图,然后它可以编译它:https ://leimao.github.io/blog/Save-Load-Inference- From-TF2-Frozen-Graph/
# Convert Keras model to ConcreteFunction
full_model = tf.function(lambda x: model(x))
full_model = full_model.get_concrete_function(
tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype))
# Get frozen ConcreteFunction
frozen_func = convert_variables_to_constants_v2(full_model)
frozen_func.graph.as_graph_def()
layers = [op.name for op in frozen_func.graph.get_operations()]
print("-" * 50)
print("Frozen model layers: ")
for layer in layers:
print(layer)
print("-" * 50)
print("Frozen model inputs: ")
print(frozen_func.inputs)
print("Frozen model outputs: ")
print(frozen_func.outputs)
# Save frozen graph from frozen ConcreteFunction to hard drive
tf.io.write_graph(graph_or_graph_def=frozen_func.graph,
logdir="./frozen_models",
name="frozen_graph.pb",
as_text=False)