3

我想用 OpenCV-DNN 包装注意力 OCR模型以增加推理时间。我正在使用官方 TF 模型repo中的 TF 代码。

对于使用 OpenCV-DNN 包装 TF 模型,我指的是此代码。需要“cv2.dnn.readNetFromTensorflow()冻结图”和“图结构”来读取 TF 模型。

我使用此代码片段从元检查点文件导入结构并将图形结构保存在.pbtxt文件中。

# load graph from meta file
tf.reset_default_graph()  
imported_meta = tf.train.import_meta_graph("attention_ocr_2017_08_09/model_demo_inference.ckpt.meta")

# restore graph structure, variables in session's graph
sess = tf.Session()
imported_meta.restore(sess, 'attention_ocr_2017_08_09/model_demo_inference.ckpt')
# write graph structure to a pbtxt file
tf.train.write_graph(sess.graph_def, './', 'train_attention.pbtxt', as_text=True)

冻结图,代码如下:

from tensorflow.python.tools import freeze_graph
freeze_graph.freeze_graph('train_attention.pbtxt', '', False, \
                          'attention_ocr_2017_08_09/model_demo_inference.ckpt', \
                          'AttentionOcr_v1_1/Softmax', \
                          'save/restore_all', 'save/Const:0', 'frozen_model.pb', True, "")

最终代码使用函数中的pbtxtpb文件cv2.dnn.readNetFromTensorflow()

# Wrap TF model in OpenCV DNN
import cv2

FROZEN_GRAPH = "frozen_model.pb"
PB_TXT = "train_attention.pbtxt"

img = cv2.imread('testdata/fsns_train_00.png')
blob = cv2.dnn.blobFromImage(img,1)

net = cv2.dnn.readNetFromTensorflow(FROZEN_GRAPH, PB_TXT)
out = net.forward()
out

遇到的错误是:

---------------------------------------------------------------------------
error                                     Traceback (most recent call last)
<ipython-input-128-09e46e8b88ed> in <module>
      9 blob = cv2.dnn.blobFromImage(img,1)
     10 
---> 11 net = cv2.dnn.readNetFromTensorflow(FROZEN_GRAPH, PB_TXT)
     12 out = net.forward()
     13 out

error: OpenCV(4.0.0) /Users/travis/build/skvark/opencv-python/opencv/modules/dnn/src/
tensorflow/tf_io.cpp:54: error: (-2:Unspecified error) 
FAILED: ReadProtoFromTextFile(param_file, param). 
Failed to parse GraphDef file: train_attention.pbtxt in function 'ReadTFNetParamsFromTextFileOrDie'

注意:输出节点名称是通过查看生成的图中的张量列表手动设置的:

# get names of all tensors
def get_names(graph=sess.graph):
    return [t.name for op in graph.get_operations() for t in op.values()]

l1 = get_names()
for ele in l1:
    print(ele)

我将非常感谢 SO 社区提供的任何帮助。

4

1 回答 1

2

就我而言,我试图.pbtxt通过我的 Google Colaboratory 访问存储在 github 存储库中的文件。我只需要这个文件,所以我没有克隆整个 repo,而是尝试使用!wget命令访问它。我做了: https://github.com/<username>/somethingelse/mask_rcnn_inception_v2_coco_2018_01_28.pbtxt。该文件似乎已下载,但是当我运行时cv2.dnn.readNetFromTensorflow,它向我抛出了一条错误消息:

OpenCV 错误:未指定的错误(失败:ReadProtoFromTextFileTF(param_file,param)。无法解析 GraphDef 文件:ssd_mobilenet_v1_coco_11_06_2017/graph.pbtxt)在 ReadTFNetParamsFromTextFileOrDie,文件 opencv-3.3.1/modules/dnn/src/tensorflow/tf_io.cpp,行72 opencv-3.3.1/modules/dnn/src/tensorflow/tf_io.cpp:72:错误:(-2)失败:ReadProtoFromTextFileTF(param_file,param)。无法解析 GraphDef 文件:函数 ReadTFNetParamsFromTextFileOrDie 中的 ssd_mobilenet_v1_coco_11_06_2017/graph.pbtxt

我意识到我应该使用rawgithub中的文件来下载,如下:

!wget https://raw.githubusercontent.com/---somethingelse as part of link---/mask_rcnn_inception_v2_coco_2018_01_28.pbtxt

由于我没有使用 raw,所以文件以 HTML 格式下载,而预期的文件是.pbtxt.

转到github存储库中的文件位置->单击右上角的原始选项->获取此原始页面的URL->使用

!wget https://raw.githubusercontent.com/ ---somethingelse as part of link---/mask_rcnn_inception_v2_coco_2018_01_28.pbtxt
于 2020-11-11T12:43:25.413 回答