I am trying to convert the frozen inference graph I obtained from running this notebook into a TFLite model.
I adapted some code I found in the TensorFlow documentation and ran it in Google Colab:
import tensorflow as tf

path = "/content/drive/MyDrive/real_frozen_inference_graph.pb"
# Renamed from `input`/`output` so the built-ins aren't shadowed
input_arrays = ["image_tensor"]
output_arrays = ["detection_boxes", "detection_scores", "detection_multiclass_scores", "detection_classes", "num_detections", "raw_detection_boxes", "raw_detection_scores"]

# Convert the model
converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(path, input_arrays=input_arrays, output_arrays=output_arrays)
print("starting conversion")
tflite_model = converter.convert()
print("done converting")

# Save the model.
with open('/content/drive/MyDrive/model.tflite', 'wb') as f:
    f.write(tflite_model)
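As a sanity check, I am also thinking of first converting a trivial frozen graph, to rule out my environment before blaming the detection model. Something like this sketch (the tiny matmul graph and the file name are just stand-ins, not my actual model):

```python
import tensorflow as tf

# Build a tiny frozen graph (a single matmul) to sanity-check the converter
# itself, independent of the real detection model.
with tf.compat.v1.Graph().as_default() as g:
    x = tf.compat.v1.placeholder(tf.float32, shape=[1, 4], name="x")
    w = tf.constant([[1.0], [2.0], [3.0], [4.0]])
    y = tf.matmul(x, w, name="y")
    with tf.compat.v1.Session(graph=g) as sess:
        # No variables here, so this just returns the graph def with "y" kept.
        frozen = tf.compat.v1.graph_util.convert_variables_to_constants(
            sess, g.as_graph_def(), ["y"])

# Serialize the frozen graph the same way a .pb on disk would look.
with open("toy_frozen.pb", "wb") as f:
    f.write(frozen.SerializeToString())

# Same conversion path as the real model, on the toy graph.
converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
    "toy_frozen.pb", input_arrays=["x"], output_arrays=["y"])
tflite_model = converter.convert()
print(len(tflite_model))
```

If this toy conversion also crashes the runtime, the problem is my setup rather than the graph.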
Almost immediately after I run the cell, the runtime disconnects, which is especially strange given that I pay for Colab Pro...
I changed the code above to run in my local development environment, but still no luck. I get a long wall of raw binary plus strange output like the following printed to my stdout, and the script terminates without ever writing the file:
%220 = "tfl.add"(%219, %cst_12) {fused_activation_function = "NONE"} : (tensor<?x100xf32>, tensor<f32>) -> tensor<?x100xf32>
%221 = "tf.TensorArraySizeV3"(%handle_170, %210#5) {_class = ["loc:@Postprocessor/BatchMultiClassNonMaxSuppression/map/TensorArray_12"], device = ""} : (tensor<2x!tf.resource<tensor<*xf32>>>, tensor<f32>) -> tensor<i32>
%222 = "tfl.range"(%cst_10, %221, %cst_13) : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<?xi32>
%223 = "tf.TensorArrayGatherV3"(%handle_170, %222, %210#5) {_class = ["loc:@Postprocessor/BatchMultiClassNonMaxSuppression/map/TensorArray_12"], device = "", element_shape = #tf.shape<100x2>} : (tensor<2x!tf.resource<tensor<*xf32>>>, tensor<?xi32>, tensor<f32>) -> tensor<?x100x2xf32>
%224 = "tf.TensorArraySizeV3"(%handle_172, %210#6) {_class = ["loc:@Postprocessor/BatchMultiClassNonMaxSuppression/map/TensorArray_13"], device = ""} : (tensor<2x!tf.resource<tensor<*xi32>>>, tensor<f32>) -> tensor<i32>
%225 = "tfl.range"(%cst_10, %224, %cst_13) : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<?xi32>
%226 = "tf.TensorArrayGatherV3"(%handle_172, %225, %210#6) {_class = ["loc:@Postprocessor/BatchMultiClassNonMaxSuppression/map/TensorArray_13"], device = "", element_shape = #tf.shape<>} : (tensor<2x!tf.resource<tensor<*xi32>>>, tensor<?xi32>, tensor<f32>) -> tensor<?xi32>
%227 = "tfl.cast"(%226) : (tensor<?xi32>) -> tensor<?xf32>
"std.return"(%213, %216, %223, %220, %227, %181, %186) : (tensor<?x100x4xf32>, tensor<?x100xf32>, tensor<?x100x2xf32>, tensor<?x100xf32>, tensor<?xf32>, tensor<?x?x4xf32>, tensor<?x?x2xf32>) -> ()
}) {sym_name = "main", tf.entry_function = {control_outputs = "", inputs = "image_tensor", outputs = "detection_boxes,detection_scores,detection_multiclass_scores,detection_classes,num_detections,raw_detection_boxes,raw_detection_scores"}, type = (tensor<?x?x?x3x!tf.quint8>) -> (tensor<?x100x4xf32>, tensor<?x100xf32>, tensor<?x100x2xf32>, tensor<?x100xf32>, tensor<?xf32>, tensor<?x?x4xf32>, tensor<?x?x2xf32>)} : () -> ()
Does anyone have any idea what is going wrong here? Is the TFLite file being printed to my stdout? Maybe I am overlooking something obvious? I am new to TensorFlow, so any help is appreciated.