0

我想在 DefaultPredictor() 中推断 Mask2Former 模型,但是当我想预测超过 50 个目标时出现 RAM 溢出。如果我调用 DefaultPredictor(cfg),我的 Google Colab 崩溃并且我的 RAM 已满。我认为内存泄漏发生在推理函数中,我尝试使用 torch.no_grad() 和 gc.collect() 来优化 RAM 使用并清除缓存,但 RAM 仍然溢出。如何优化内存使用或重写推理函数?

# --- Build the Mask2Former config ---
cfg = get_cfg()
add_maskformer2_config(cfg)

config_file = "Mask2Former/configs/coco/instance-segmentation/maskformer2_R50_bs16_50ep.yaml"

# NOTE(review): these two keys must exist before merge_from_file or the merge
# fails — presumably the YAML references ResNet keys that
# add_maskformer2_config does not register; confirm against the config file.
cfg.MODEL.RESNETS.STEM_TYPE = "basic"
cfg.MODEL.RESNETS.RES5_MULTI_GRID = [1, 1, 1]

cfg.merge_from_file(config_file)

# Checkpoint, dataset, and inference-time overrides.
cfg.MODEL.WEIGHTS = "/content/drive/MyDrive/gigaflopps/model_0000499.pth"
cfg.DATASETS.TEST = ("my_dataset_val", )
cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES = 1
cfg.INPUT.IMAGE_SIZE = 512
cfg.INPUT.FORMAT = 'RGB'
cfg.TEST.DETECTIONS_PER_IMAGE = 10

# --- Run inference on one validation image ---
predictor = DefaultPredictor(cfg)

sample_idx = 10
sample = dataset_dicts_val[sample_idx]
image = cv2.imread(sample["file_name"])
outputs = predictor(image)

# --- Visualize the predicted instances ---
plt.figure(figsize=(7, 7))
viz = Visualizer(
    image[:, :],
    metadata=val_metadata,
    scale=0.4,
)
viz = viz.draw_instance_predictions(outputs["instances"].to("cpu"))
# cv2.imread returns BGR; reverse channels for matplotlib display.
plt.imshow(viz.get_image()[:, :, ::-1])
plt.axis('off')
plt.show()

源代码来自 https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/maskformer_model.py#L344 :

def instance_inference(self, mask_cls, mask_pred):
        """Convert per-query class logits and mask logits into an `Instances` result.

        Args:
            mask_cls: [Q, K+1] classification logits per query; the last column
                is the "no object" class and is dropped after softmax.
            mask_pred: [Q, H, W] mask logits, already resized to the input size.

        Returns:
            Instances with `pred_masks` (binarized, float), `pred_boxes`
            (zero placeholders), `scores`, and `pred_classes` for the
            top `self.test_topk_per_image` (query, class) pairs.
        """
        # mask_pred is already processed to have the same shape as original input
        image_size = mask_pred.shape[-2:]
        num_classes = self.sem_seg_head.num_classes

        # [Q, K]: softmax over classes (incl. "no object"), then drop that column.
        scores = F.softmax(mask_cls, dim=-1)[:, :-1]
        # Flat class-label lookup table aligned with scores.flatten(0, 1).
        labels = (
            torch.arange(num_classes, device=self.device)
            .unsqueeze(0)
            .repeat(self.num_queries, 1)
            .flatten(0, 1)
        )
        # Top-k over all (query, class) pairs jointly.
        scores_per_image, topk_indices = scores.flatten(0, 1).topk(
            self.test_topk_per_image, sorted=False
        )
        labels_per_image = labels[topk_indices]

        # Map each flat (query, class) index back to its query index.
        # torch.div with rounding_mode replaces the deprecated tensor `//` form.
        topk_indices = torch.div(topk_indices, num_classes, rounding_mode="floor")
        mask_pred = mask_pred[topk_indices]

        # If this is panoptic segmentation, we only keep the "thing" classes.
        if self.panoptic_on:
            # Hoisted out of the per-query loop: the contiguous-id set is
            # loop-invariant, so build it once and test membership cheaply.
            thing_ids = set(self.metadata.thing_dataset_id_to_contiguous_id.values())
            keep = torch.tensor(
                [int(lab) in thing_ids for lab in labels_per_image],
                dtype=torch.bool,
                device=scores_per_image.device,
            )
            scores_per_image = scores_per_image[keep]
            labels_per_image = labels_per_image[keep]
            mask_pred = mask_pred[keep]

        result = Instances(image_size)
        # mask (before sigmoid)
        result.pred_masks = (mask_pred > 0).float()
        # Placeholder boxes; kept on CPU as in the reference implementation.
        result.pred_boxes = Boxes(torch.zeros(mask_pred.size(0), 4))
        # Uncomment the following to get boxes from masks (this is slow)
        # result.pred_boxes = BitMasks(mask_pred > 0).get_bounding_boxes()

        # Calculate average mask probability inside each binarized mask;
        # flatten once and reuse instead of recomputing it per term.
        binary_masks = result.pred_masks.flatten(1)
        mask_scores_per_image = (mask_pred.sigmoid().flatten(1) * binary_masks).sum(1) / (
            binary_masks.sum(1) + 1e-6
        )
        result.scores = scores_per_image * mask_scores_per_image
        result.pred_classes = labels_per_image
        return result
4

0 回答 0