1

我正在尝试使用 COCO 2014 数据在 PyTorch 中进行语义分割训练。我有一个带有交叉熵损失函数的 PSPNet 模型,该模型在 2012 年的 PASCAL VOC 数据集上完美运行。现在我正在尝试使用一部分 COCO 图片来执行相同的过程。但是 Coco 有 json 数据而不是 .png 图像用于注释,我不得不以某种方式将一个转换为另一个。我注意到 cocotools 中有 annToMask,但我无法安静地弄清楚如何在我的情况下使用该功能。这就是我的数据加载器的拉取项目的样子

def pull_item(self, index):

        I DONT KNOW WHAT TO DO HERE

        raw_img = self.transform(raw_img)
        anns_img = self.transform(anns_img)

        return raw_img, anns_img

下面是我使用来自数据加载器的数据的训练函数的样子。

 for images, labels in dataloaders_dict[phase]:

                images = images.to(device)

                labels = torch.squeeze(labels)
                labels = labels.to(device)

                with torch.set_grad_enabled(phase == 'train'):
                    outputs = net(images)

                    loss = criterion(outputs, labels.long())
4

2 回答 2

2

我一直致力于使用 PyCOCO 为 COCO 数据集创建数据生成器,我认为我的经验可以帮助你。我在媒体上的帖子记录了从开始到结束的整个过程,包括创建面具。

但是,需要注意的是,我使用的是 Tensorflow Keras 而不是 pytorch。但是逻辑流程应该大致相同,所以我相信你可以从中取回一些有用的东西。

于 2020-05-06T07:50:35.903 回答
0

感谢上面的答案,我能够创建这个:

class ImageData(Dataset):
    def __init__(
        self, 
        annotations: COCO, 
        img_ids: List[int], 
        cat_ids: List[int], 
        root_path: Path, 
        transform: Optional[Callable]=None
    ) -> None:
        super().__init__()
        self.annotations = annotations
        self.img_data = annotations.loadImgs(img_ids)
        self.cat_ids = cat_ids
        self.files = [str(root_path / img["file_name"]) for img in self.img_data]
        self.transform = transform
        
    def __len__(self) -> int:
        return len(self.files)
    
    def __getitem__(self, i: int) -> Tuple[torch.Tensor, torch.LongTensor]:
        ann_ids = self.annotations.getAnnIds(
            imgIds=self.img_data[i]['id'], 
            catIds=self.cat_ids, 
            iscrowd=None
        )
        anns = self.annotations.loadAnns(ann_ids)
        mask = torch.LongTensor(np.max(np.stack([self.annotations.annToMask(ann) * ann["category_id"] 
                                                 for ann in anns]), axis=0)).unsqueeze(0)
        
        img = io.read_image(self.files[i])
        if img.shape[0] == 1:
            img = torch.cat([img]*3)
        
        if self.transform is not None:
            return self.transform(img, mask)
        
        return img, mask

完整的帖子可以在这个kaggle 内核中找到。

于 2021-08-21T08:18:38.243 回答