I want to run inference on a Mask2Former model with DefaultPredictor(), but I hit a RAM overflow once I predict more than about 50 images. After I call DefaultPredictor(cfg), my Google Colab session crashes because RAM fills up. I believe the memory leak is inside the inference function; I tried torch.no_grad() and gc.collect() to clear the caches and reduce RAM usage, but RAM still fills up and the session crashes. How can I optimize the memory usage or rewrite the inference function?
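Roughly what my prediction loop looks like, simplified (it uses the predictor and dataset_dicts_val set up below):

import gc
import torch
import cv2

results = []
for example in dataset_dicts_val:
    im = cv2.imread(example["file_name"])
    # no_grad avoids building the autograd graph during inference
    with torch.no_grad():
        outputs = predictor(im)
    results.append(outputs["instances"].to("cpu"))  # keep results off the GPU
    del outputs
    # try to release Python- and CUDA-side caches after each image
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

(As far as I can tell, DefaultPredictor already runs its forward pass under torch.no_grad(), which would explain why adding it changed nothing.) My configuration: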
cfg = get_cfg()
add_maskformer2_config(cfg)
config_file = "Mask2Former/configs/coco/instance-segmentation/maskformer2_R50_bs16_50ep.yaml"
## Not sure why these two lines are needed, but merging the config fails without them.
cfg.MODEL.RESNETS.STEM_TYPE = "basic"
cfg.MODEL.RESNETS.RES5_MULTI_GRID = [1, 1, 1]
cfg.merge_from_file(config_file)
cfg.MODEL.WEIGHTS = "/content/drive/MyDrive/gigaflopps/model_0000499.pth"
cfg.DATASETS.TEST = ("my_dataset_val", )
cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES = 1
cfg.INPUT.IMAGE_SIZE = 512
cfg.INPUT.FORMAT = 'RGB'
cfg.TEST.DETECTIONS_PER_IMAGE = 10
predictor = DefaultPredictor(cfg)
id_image_selected = 10
example = dataset_dicts_val[id_image_selected]
im = cv2.imread(example["file_name"])
outputs = predictor(im)
plt.figure(figsize=(7,7))
v = Visualizer(im[:, :, ::-1],  # OpenCV loads BGR; Visualizer expects RGB
               metadata=val_metadata,
               scale=0.4)
v = v.draw_instance_predictions(outputs["instances"].to("cpu"))
plt.imshow(v.get_image())
plt.axis('off')
plt.show()
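The post-processing upsamples every query mask to the full input resolution before instance_inference runs (the comment in the source below says mask_pred already has the input shape), so I suspect peak RAM scales with both image size and the number of kept masks. One thing I am considering is lowering the test-time resolution with detectron2's standard resize knobs; a sketch (the size values are guesses, not tuned):

cfg.INPUT.MIN_SIZE_TEST = 384
cfg.INPUT.MAX_SIZE_TEST = 640
predictor = DefaultPredictor(cfg)  # rebuild the predictor so the new sizes apply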
The source code of the inference function, from https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/maskformer_model.py#L344:
def instance_inference(self, mask_cls, mask_pred):
    # mask_pred is already processed to have the same shape as original input
    image_size = mask_pred.shape[-2:]

    # [Q, K]
    scores = F.softmax(mask_cls, dim=-1)[:, :-1]
    labels = torch.arange(self.sem_seg_head.num_classes, device=self.device).unsqueeze(0).repeat(self.num_queries, 1).flatten(0, 1)
    # scores_per_image, topk_indices = scores.flatten(0, 1).topk(self.num_queries, sorted=False)
    scores_per_image, topk_indices = scores.flatten(0, 1).topk(self.test_topk_per_image, sorted=False)
    labels_per_image = labels[topk_indices]

    topk_indices = topk_indices // self.sem_seg_head.num_classes
    # mask_pred = mask_pred.unsqueeze(1).repeat(1, self.sem_seg_head.num_classes, 1).flatten(0, 1)
    mask_pred = mask_pred[topk_indices]

    # if this is panoptic segmentation, we only keep the "thing" classes
    if self.panoptic_on:
        keep = torch.zeros_like(scores_per_image).bool()
        for i, lab in enumerate(labels_per_image):
            keep[i] = lab in self.metadata.thing_dataset_id_to_contiguous_id.values()

        scores_per_image = scores_per_image[keep]
        labels_per_image = labels_per_image[keep]
        mask_pred = mask_pred[keep]

    result = Instances(image_size)
    # mask (before sigmoid)
    result.pred_masks = (mask_pred > 0).float()
    result.pred_boxes = Boxes(torch.zeros(mask_pred.size(0), 4))
    # Uncomment the following to get boxes from masks (this is slow)
    # result.pred_boxes = BitMasks(mask_pred > 0).get_bounding_boxes()

    # calculate average mask prob
    mask_scores_per_image = (mask_pred.sigmoid().flatten(1) * result.pred_masks.flatten(1)).sum(1) / (result.pred_masks.flatten(1).sum(1) + 1e-6)
    result.scores = scores_per_image * mask_scores_per_image
    result.pred_classes = labels_per_image
    ## try gc.collect()
    return result
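If I read this right, the heavy tensors are mask_pred (a [Q, H, W] float32 tensor at full input resolution) and the second float copy created by (mask_pred > 0).float(). A sketch of the leaner rewrite I have in mind, untested: keep the masks as bool and drop intermediates explicitly. It omits the panoptic branch I don't use, and I am not sure every downstream consumer (e.g. the Visualizer) accepts bool masks:

import gc
import torch
from torch.nn import functional as F
from detectron2.structures import Boxes, Instances

def instance_inference_low_mem(self, mask_cls, mask_pred):
    image_size = mask_pred.shape[-2:]

    # [Q, K] class scores, without the trailing "no object" column
    scores = F.softmax(mask_cls, dim=-1)[:, :-1]
    labels = torch.arange(self.sem_seg_head.num_classes, device=self.device).unsqueeze(0).repeat(self.num_queries, 1).flatten(0, 1)
    scores_per_image, topk_indices = scores.flatten(0, 1).topk(self.test_topk_per_image, sorted=False)
    labels_per_image = labels[topk_indices]
    topk_indices = torch.div(topk_indices, self.sem_seg_head.num_classes, rounding_mode="floor")

    # slice the selected queries, then drop the local reference to the
    # full [Q, H, W] tensor (the caller may still hold one)
    selected = mask_pred[topk_indices]
    del mask_pred, scores, labels

    # bool masks take 1 byte/pixel instead of the 4-byte float copy the original keeps
    binary_masks = selected > 0
    # average sigmoid probability inside each predicted mask
    mask_scores_per_image = (selected.sigmoid().flatten(1) * binary_masks.flatten(1)).sum(1) / (binary_masks.flatten(1).sum(1) + 1e-6)
    del selected

    result = Instances(image_size)
    result.pred_masks = binary_masks  # cast with .float() only if a consumer needs it
    result.pred_boxes = Boxes(torch.zeros(binary_masks.size(0), 4))
    result.scores = scores_per_image * mask_scores_per_image
    result.pred_classes = labels_per_image

    gc.collect()
    return result

Would that direction be sound, or is the real leak somewhere else?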