我正在尝试获取由 Pytorches DefaultPredictor 生成的蒙版内像素的坐标,以便稍后获取多边形角并在我的应用程序中使用它。
但是,DefaultPredictor 产生了 pred_masks 的张量,格式如下: [False, False ... False], ... [False, False, .. False] 其中每个单独列表的长度是图像的长度,并且总列表的数量是图像的高度。
现在,由于我需要获取掩码内的像素坐标,因此简单的解决方案似乎是循环遍历 pred_masks,检查值并 if == "True" 创建这些元组并将它们添加到列表中。然而,当我们谈论宽度 x 高度约为 3200 x 1600 的图像时,这是一个相对较慢的过程(循环单个 3200x1600 大约需要 4 秒,但是因为有很多对象需要我进行推理最后 - 这最终会变得非常缓慢)。
使用pytorch(detectron2)模型获取检测到的对象的坐标(掩码)的更智能方法是什么?
请在下面找到我的代码以供参考:
from __future__ import print_function
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.data import MetadataCatalog
from detectron2.data.datasets import register_coco_instances
import cv2
import time
# get image
start = time.time()
im = cv2.imread("inputImage.jpg")
# Create config
cfg = get_cfg()
cfg.merge_from_file("detectron2_repo/configs/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5 # Set threshold for this model
cfg.MODEL.WEIGHTS = "model_final.pth" # Set path model .pth
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1
cfg.MODEL.DEVICE='cpu'
register_coco_instances("dataset_test",{},"testval.json","Images_path")
test_metadata = MetadataCatalog.get("dataset_test")
# Create predictor
predictor = DefaultPredictor(cfg)
# Make prediction
outputs = predictor(im)
#Loop through the pred_masks and check which ones are equal to TRUE, if equal, add the pixel values to the true_cords_list
outputnump = outputs["instances"].pred_masks.numpy()
true_cords_list = []
x_length = range(len(outputnump[0][0]))
#y kordinaat on range number
for y_cord in range(len(outputnump[0])):
#x cord
for x_cord in x_length:
if str(outputnump[0][y_cord][x_cord]) == "True":
inputcoords = (x_cord,y_cord)
true_cords_list.append(inputcoords)
print(str(true_cords_list))
end = time.time()
print(f"Runtime of the program is {end - start}") # 14.29468035697937
//
编辑:在将 for 循环部分更改为压缩后 - 我设法将 for 循环的运行时间减少了 ~3 倍 - 但是,理想情况下,如果可能的话,我希望从预测器本身接收它。
y_length = len(outputnump[0])
x_length = len(outputnump[0][0])
true_cords_list = []
for y_cord in range(y_length):
x_cords = list(compress(range(x_length), outputnump[0][y_cord]))
if x_cords:
for x_cord in x_cords:
inputcoords = (x_cord,y_cord)
true_cords_list.append(inputcoords)