我已经成功使用 TensorflowSharp 和 Faster RCNN 有一段时间了;然而,我最近训练了一个 Retinanet 模型,验证了它在 python 中的工作,并创建了一个用于 Tensorflow 的冻结 pb 文件。对于 FRCNN,TensorflowSharp GitHub 存储库中有一个示例,展示了如何运行/获取此模型。对于 Retinanet,我尝试修改代码,但似乎没有任何效果。我有一个我尝试使用的 Retinanet 模型摘要,但对我来说应该使用什么并不明显。
对于 FRCNN,图形以这种方式运行:
var runner = m_session.GetRunner();
runner
.AddInput(m_graph["image_tensor"][0], tensor)
.Fetch(
m_graph["detection_boxes"][0],
m_graph["detection_scores"][0],
m_graph["detection_classes"][0],
m_graph["num_detections"][0]);
var output = runner.Run();
var boxes = (float[,,])output[0].GetValue(jagged: false);
var scores = (float[,])output[1].GetValue(jagged: false);
var classes = (float[,])output[2].GetValue(jagged: false);
var num = (float[])output[3].GetValue(jagged: false);
从 FRCNN 的模型摘要中,很明显输入(“image_tensor”)和输出(“detection_boxes”、“detection_scores”、“detection_classes”和“num_detections”)是什么。对于 Retinanet(我已经尝试过),它们不一样,我无法弄清楚它们应该是什么。上面代码的“获取”部分导致崩溃,我猜是因为我没有正确获取节点名称。
我不会在此处粘贴整个 Retinanet 摘要,但这里是前几个节点:
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_1 (InputLayer) (None, None, None, 3 0
__________________________________________________________________________________________________
padding_conv1 (ZeroPadding2D) (None, None, None, 3 0 input_1[0][0]
__________________________________________________________________________________________________
conv1 (Conv2D) (None, None, None, 6 9408 padding_conv1[0][0]
__________________________________________________________________________________________________
bn_conv1 (BatchNormalization) (None, None, None, 6 256 conv1[0][0]
__________________________________________________________________________________________________
conv1_relu (Activation) (None, None, None, 6 0 bn_conv1[0][0]
__________________________________________________________________________________________________
这是最后几个节点:
__________________________________________________________________________________________________
anchors_0 (Anchors) (None, None, 4) 0 P3[0][0]
__________________________________________________________________________________________________
anchors_1 (Anchors) (None, None, 4) 0 P4[0][0]
__________________________________________________________________________________________________
anchors_2 (Anchors) (None, None, 4) 0 P5[0][0]
__________________________________________________________________________________________________
anchors_3 (Anchors) (None, None, 4) 0 P6[0][0]
__________________________________________________________________________________________________
anchors_4 (Anchors) (None, None, 4) 0 P7[0][0]
__________________________________________________________________________________________________
regression_submodel (Model) (None, None, 4) 2443300 P3[0][0]
P4[0][0]
P5[0][0]
P6[0][0]
P7[0][0]
__________________________________________________________________________________________________
anchors (Concatenate) (None, None, 4) 0 anchors_0[0][0]
anchors_1[0][0]
anchors_2[0][0]
anchors_3[0][0]
anchors_4[0][0]
__________________________________________________________________________________________________
regression (Concatenate) (None, None, 4) 0 regression_submodel[1][0]
regression_submodel[2][0]
regression_submodel[3][0]
regression_submodel[4][0]
regression_submodel[5][0]
__________________________________________________________________________________________________
boxes (RegressBoxes) (None, None, 4) 0 anchors[0][0]
regression[0][0]
__________________________________________________________________________________________________
classification_submodel (Model) (None, None, 1) 2381065 P3[0][0]
P4[0][0]
P5[0][0]
P6[0][0]
P7[0][0]
__________________________________________________________________________________________________
clipped_boxes (ClipBoxes) (None, None, 4) 0 input_1[0][0]
boxes[0][0]
__________________________________________________________________________________________________
classification (Concatenate) (None, None, 1) 0 classification_submodel[1][0]
classification_submodel[2][0]
classification_submodel[3][0]
classification_submodel[4][0]
classification_submodel[5][0]
__________________________________________________________________________________________________
filtered_detections (FilterDete [(None, 300, 4), (No 0 clipped_boxes[0][0]
classification[0][0]
==================================================================================================
Total params: 36,382,957
Trainable params: 36,276,717
Non-trainable params: 106,240
任何有关如何解决“获取”部分的帮助将不胜感激。
编辑:
为了更深入地研究这一点,我找到了一个 python 函数来打印 .pb 文件中的操作名称。为 FRCNN .pb 文件执行此操作时,它清楚地给出了输出节点名称,如下所示(仅发布 python 函数输出的最后几行)。
import/SecondStagePostprocessor/BatchMultiClassNonMaxSuppression/map/TensorArrayStack_4/TensorArrayGatherV3
import/SecondStagePostprocessor/ToFloat_1
import/add/y
import/add
import/detection_boxes
import/detection_scores
import/detection_classes
import/num_detections
如果我对 Retinanet .pb 文件做同样的事情,那么输出是什么并不明显。这是python函数的最后几行。
import/filtered_detections/map/while/NextIteration_4
import/filtered_detections/map/while/Exit_2
import/filtered_detections/map/while/Exit_3
import/filtered_detections/map/while/Exit_4
import/filtered_detections/map/TensorArrayStack/TensorArraySizeV3
import/filtered_detections/map/TensorArrayStack/range/start
import/filtered_detections/map/TensorArrayStack/range/delta
import/filtered_detections/map/TensorArrayStack/range
import/filtered_detections/map/TensorArrayStack/TensorArrayGatherV3
import/filtered_detections/map/TensorArrayStack_1/TensorArraySizeV3
import/filtered_detections/map/TensorArrayStack_1/range/start
import/filtered_detections/map/TensorArrayStack_1/range/delta
import/filtered_detections/map/TensorArrayStack_1/range
import/filtered_detections/map/TensorArrayStack_1/TensorArrayGatherV3
import/filtered_detections/map/TensorArrayStack_2/TensorArraySizeV3
import/filtered_detections/map/TensorArrayStack_2/range/start
import/filtered_detections/map/TensorArrayStack_2/range/delta
import/filtered_detections/map/TensorArrayStack_2/range
import/filtered_detections/map/TensorArrayStack_2/TensorArrayGatherV3
作为参考,这是我使用的 python 函数:
def printTensors(pb_file):
# read pb into graph_def
with tf.gfile.GFile(pb_file, "rb") as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
# import graph_def
with tf.Graph().as_default() as graph:
tf.import_graph_def(graph_def)
# print operations
for op in graph.get_operations():
print(op.name)
希望这可以帮助。