I have a batch of images cropped from an original image, and I need to run object detection on them. I'm trying to apply TensorFlow's NMS operation.
I looked through the TensorFlow API docs and found `tf.image.combined_non_max_suppression()`, but I can't quite understand how to use it.
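To make the input contract of `tf.image.combined_non_max_suppression()` concrete, here is a simplified NumPy reference of what it computes, based on its documented shapes: `boxes` must be 4-D `[batch_size, num_boxes, q, 4]` and `scores` 3-D `[batch_size, num_boxes, num_classes]`. For each image, greedy NMS runs separately per class (keeping at most `max_output_size_per_class` boxes each), the survivors of all classes are merged by score, cut to `max_total_size`, and zero-padded, with `valid_detections` counting the real entries. This is a sketch, not TensorFlow's implementation: `combined_nms_reference` and its helpers are hypothetical names, only `q = 1` (one box shared by all classes) is handled, and the `clip_boxes` and `pad_per_class` options are ignored.

```python
import numpy as np

def _iou(box, boxes):
    """IoU of one [y1, x1, y2, x2] box against an [n, 4] array of boxes."""
    y1 = np.maximum(box[0], boxes[:, 0])
    x1 = np.maximum(box[1], boxes[:, 1])
    y2 = np.minimum(box[2], boxes[:, 2])
    x2 = np.minimum(box[3], boxes[:, 3])
    inter = np.clip(y2 - y1, 0, None) * np.clip(x2 - x1, 0, None)
    area = (box[2] - box[0]) * (box[3] - box[1])
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    return inter / np.maximum(area + areas - inter, 1e-9)

def _greedy_nms(boxes, scores, max_out, iou_threshold, score_threshold):
    """Indices kept by greedy NMS for one class, highest score first."""
    order = np.argsort(-scores, kind="stable")
    order = order[scores[order] > score_threshold]
    keep = []
    while order.size > 0 and len(keep) < max_out:
        best, rest = order[0], order[1:]
        keep.append(best)
        # Drop every remaining box that overlaps the chosen one too much.
        order = rest[_iou(boxes[best], boxes[rest]) <= iou_threshold]
    return keep

def combined_nms_reference(boxes, scores, max_output_size_per_class, max_total_size,
                           iou_threshold=0.5, score_threshold=float("-inf")):
    """boxes: [batch, num_boxes, 1, 4]; scores: [batch, num_boxes, num_classes]."""
    if boxes.ndim != 4 or boxes.shape[2] != 1:
        raise ValueError(f"boxes must be 4-D [batch, num_boxes, 1, 4], got {boxes.shape}")
    if scores.ndim != 3:
        raise ValueError(f"scores must be 3-D [batch, num_boxes, num_classes], got {scores.shape}")
    batch = boxes.shape[0]
    out_boxes = np.zeros((batch, max_total_size, 4), dtype=np.float32)
    out_scores = np.zeros((batch, max_total_size), dtype=np.float32)
    out_classes = np.zeros((batch, max_total_size), dtype=np.float32)
    valid = np.zeros(batch, dtype=np.int32)
    for b in range(batch):
        img_boxes = boxes[b, :, 0, :]
        candidates = []  # (score, class_index, box_index)
        for c in range(scores.shape[2]):
            for i in _greedy_nms(img_boxes, scores[b, :, c], max_output_size_per_class,
                                 iou_threshold, score_threshold):
                candidates.append((scores[b, i, c], c, i))
        # Merge the per-class survivors by score and keep the best max_total_size.
        candidates.sort(key=lambda t: -t[0])
        candidates = candidates[:max_total_size]
        for k, (s, c, i) in enumerate(candidates):
            out_boxes[b, k], out_scores[b, k], out_classes[b, k] = img_boxes[i], s, c
        valid[b] = len(candidates)
    return out_boxes, out_scores, out_classes, valid
```

Passing a 3-D `[batch, num_boxes, 4]` boxes array to this sketch fails its shape check, which is the same kind of shape problem reported in the traceback later in this question.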
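My pipeline has two steps:
- I take some images and run object detection on them to get the regions of interest (ROIs) I need.
- On each of these ROIs I have to run object detection again, so I pass them in as a batch.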
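For the second step, the ROIs from the first step have to become a fixed-size image batch before they can be fed to the detector again. The sketch below does this in NumPy with nearest-neighbour sampling, taking normalized `[y1, x1, y2, x2]` boxes (the box format TensorFlow detection outputs use) and, as far as I know, the same coordinate convention as `tf.image.crop_and_resize` (a normalized `y` maps to `y * (height - 1)`). The helper `crop_rois` and the fixed `crop_size` are hypothetical, for illustration only; the question does not show how the crops are actually produced.

```python
import numpy as np

def crop_rois(image, boxes, crop_size):
    """Crop normalized [y1, x1, y2, x2] boxes from one HxWxC image and resize
    each crop to crop_size = (h, w) with nearest-neighbour sampling.
    Returns a batch of shape [num_boxes, h, w, C]."""
    height, width = image.shape[:2]
    out_h, out_w = crop_size
    crops = np.zeros((len(boxes), out_h, out_w, image.shape[2]), dtype=image.dtype)
    for k, (y1, x1, y2, x2) in enumerate(boxes):
        # Evenly spaced sample positions inside the box, in pixel coordinates.
        ys = np.round(np.linspace(y1, y2, out_h) * (height - 1)).astype(int)
        xs = np.round(np.linspace(x1, x2, out_w) * (width - 1)).astype(int)
        ys = np.clip(ys, 0, height - 1)
        xs = np.clip(xs, 0, width - 1)
        # np.ix_ picks the full grid of (row, column) pairs, keeping all channels.
        crops[k] = image[np.ix_(ys, xs)]
    return crops
```

The result has shape `[num_rois, h, w, channels]`, so the leading dimension of every second-step output is the number of ROIs. With `max_output_size` set to 20 in the first-step NMS, that would be consistent with the leading 20 in the `[20,300,4]` shape from the error.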
For the first step I use plain `tf.image.non_max_suppression()` followed by `tf.gather()`, but I don't understand how to do the second step.
Please see the following code snippet:
```python
with tf.Session(graph = self.detection_graph) as sess:
    # input image tensor
    image_tensor1 = self.detection_graph.get_tensor_by_name('import/image_tensor:0')

    # boxes, scores and classes for first step
    boxesop1 = self.detection_graph.get_tensor_by_name('import/detection_boxes:0')
    scoresop1 = self.detection_graph.get_tensor_by_name('import/detection_scores:0')
    classesop1 = self.detection_graph.get_tensor_by_name('import/detection_classes:0')

    # getting first values, since we are predicting on single image
    boxesop1 = boxesop1[0]
    classesop1 = classesop1[0]
    scoresop1 = scoresop1[0]

    # applying NMS for the first step
    selected_indices1 = tf.image.non_max_suppression(
        boxesop1, scoresop1, 20, iou_threshold = 0.5
    )
    boxesop1 = tf.gather(boxesop1, selected_indices1)
    classesop1 = tf.gather(classesop1, selected_indices1)
    scoresop1 = tf.gather(scoresop1, selected_indices1)

    # boxes, scores and classes for second step
    boxesop2 = self.detection_graph.get_tensor_by_name('import_1/detection_boxes:0')
    scoresop2 = self.detection_graph.get_tensor_by_name('import_1/detection_scores:0')
    classesop2 = self.detection_graph.get_tensor_by_name('import_1/detection_classes:0')

    # applying NMS for the second step
    boxesop2, scoresop2, classesop2, valid_detections = tf.image.combined_non_max_suppression(
        boxesop2, scoresop2, max_output_size_per_class = 10, max_total_size = 30,
        iou_threshold = 0.5
    )

    # predicting for each images
    for imgPath, imgID in img_files:
        # reading image data
        img = cv2.imread(imgPath)
        imageHeight, imageWidth = img.shape[:2]

        # Expand dimensions since the model expects images to have shape: [1, None, None, 3]
        image_np_expanded = np.expand_dims(img, axis=0)

        # Run inference
        (boxes1, scores1, classes1, boxes2, scores2, classes2) = sess.run(
            [boxesop1, scoresop1, classesop1, boxesop2, scoresop2, classesop2],
            feed_dict={image_tensor1: image_np_expanded}
        )
```
But when I run the code above, I get the following error:
```
Traceback (most recent call last):
  File "../env/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1356, in _do_call
    return fn(*args)
  File "../env/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1341, in _run_fn
    options, feed_dict, fetch_list, target_list, run_metadata)
  File "../env/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1429, in _call_tf_sessionrun
    run_metadata)
tensorflow.python.framework.errors_impl.InvalidArgumentError: boxes must be 4-D[20,300,4]
    [[{{node combined_non_max_suppression/CombinedNonMaxSuppression}}]]

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/prediction.py", line 159, in predict
    feed_dict={image_tensor1: image_np_expanded}
  File "../env/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 950, in run
    run_metadata_ptr)
  File "../env/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1173, in _run
    feed_dict_tensor, options, run_metadata)
  File "../env/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1350, in _do_run
    run_metadata)
  File "../env/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1370, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: boxes must be 4-D[20,300,4]
    [[node combined_non_max_suppression/CombinedNonMaxSuppression (defined at /home/prediction.py:130) ]]

Errors may have originated from an input operation.
Input Source operations connected to node combined_non_max_suppression/CombinedNonMaxSuppression:
 import_1/detection_boxes (defined at /home/prediction.py:94)

Original stack trace for 'combined_non_max_suppression/CombinedNonMaxSuppression':
  File "/home/prediction.py", line 130, in predict
    iou_threshold = 0.5
  File "../env/lib/python3.5/site-packages/tensorflow/python/ops/image_ops_impl.py", line 3707, in combined_non_max_suppression
    score_threshold, pad_per_class, clip_boxes)
  File "../env/lib/python3.5/site-packages/tensorflow/python/ops/gen_image_ops.py", line 431, in combined_non_max_suppression
    clip_boxes=clip_boxes, name=name)
  File "../env/lib/python3.5/site-packages/tensorflow/python/framework/op_def_library.py", line 788, in _apply_op_helper
    op_def=op_def)
  File "../env/lib/python3.5/site-packages/tensorflow/python/util/deprecation.py", line 507, in new_func
    return func(*args, **kwargs)
  File "../env/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 3616, in create_op
    op_def=op_def)
  File "../env/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 2005, in __init__
    self._traceback = tf_stack.extract_stack()
```
How can I fix this and apply NMS to a batch of images in TensorFlow?
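The error says that `import_1/detection_boxes` has shape `[20, 300, 4]` (one box per detection), while `combined_non_max_suppression` requires 4-D boxes `[batch, num_boxes, q, 4]`; likewise `detection_scores` is `[batch, num_boxes]` rather than `[batch, num_boxes, num_classes]`. One possible fix, sketched here in NumPy, is to insert a `q = 1` axis into the boxes and spread each detection's single score into a per-class score vector using its class id. The helper name `to_combined_nms_inputs`, the `num_classes` value, and the assumption that `detection_classes` holds 1-based class ids (as models exported with the TensorFlow Object Detection API usually do) are illustrative assumptions, not taken from the question.

```python
import numpy as np

def to_combined_nms_inputs(boxes, scores, classes, num_classes, label_offset=1):
    """Reshape per-detection outputs into the shapes combined NMS expects.

    boxes:   [batch, num_boxes, 4]
    scores:  [batch, num_boxes]   (score of each detection's predicted class)
    classes: [batch, num_boxes]   (class ids, assumed to start at label_offset)
    Returns boxes [batch, num_boxes, 1, 4] and scores [batch, num_boxes, num_classes].
    """
    boxes = np.asarray(boxes, dtype=np.float32)
    scores = np.asarray(scores, dtype=np.float32)
    class_idx = np.asarray(classes).astype(np.int64) - label_offset
    # q = 1: each detection has one box, shared by all classes.
    boxes4d = np.expand_dims(boxes, axis=2)
    # One-hot over classes; ids outside [0, num_classes), e.g. zero padding,
    # get an all-zero row, matching how tf.one_hot treats out-of-range indices.
    valid = (class_idx >= 0) & (class_idx < num_classes)
    one_hot = np.zeros(class_idx.shape + (num_classes,), dtype=np.float32)
    b, n = np.nonzero(valid)
    one_hot[b, n, class_idx[b, n]] = 1.0
    # Put each detection's score in the column of its class, zeros elsewhere.
    scores3d = one_hot * scores[..., None]
    return boxes4d, scores3d
```

Inside the graph, the same reshaping of the tensors read from `import_1/...` could be written as `tf.expand_dims(boxesop2, axis=2)` for the boxes and `tf.one_hot(tf.cast(classesop2, tf.int32) - 1, depth=num_classes) * tf.expand_dims(scoresop2, axis=-1)` for the scores, before calling `tf.image.combined_non_max_suppression()`. The returned `nmsed_classes` would then be 0-based indices into the class axis, so 1 would need to be added back to recover the original label ids. If per-class suppression is not needed, `tf.expand_dims(scoresop2, axis=-1)` alone (a single class) also satisfies the shape requirement, at the cost of losing the class labels in the output.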