0

我是pytorch的新手。我正在尝试为图像数据集创建一个 DataLoader,其中每个图像都有一个相应的基本事实(同名):

root:
--->RGB:
------>img1.png
------>img2.png
------>...
------>imgN.png
--->GT:
------>img1.png
------>img2.png
------>...
------>imgN.png

当我使用根文件夹的路径(包含 RGB 和 GT 文件夹)作为输入时,torchvision.datasets.ImageFolder它会读取所有图像,就好像它们都用于输入(归类为 RGB 和 GT)一样,似乎没有办法对 RGB-GT 图像进行配对。我想将 RGB-GT 图像配对、随机播放,然后将其分成定义大小的批次。如何做呢?任何建议将被认真考虑。谢谢。

4

1 回答 1

1

我认为,好的起点是使用VisionDataset类作为基础。我们将在这里使用的是:DatasetFolder 源代码。所以,我们要创建类似的东西。你会注意到这个类依赖于 datasets.folder模块中的另外两个函数:default_loadermake_dataset

我们不打算修改default_loader,因为它已经很好了,它只是帮助我们加载图像,所以我们将它导入。

但是我们需要一个新make_dataset函数,从根文件夹准备正确的图像对。由于原始make_dataset对图像(更准确地说是图像路径)和它们的根文件夹作为目标类(类索引),我们有一个(path, class_to_idx[target])对列表,但我们需要(rgb_path, gt_path). 这是 new 的代码make_dataset

def make_dataset(root: str) -> list:
    """Reads a directory with data.
    Returns a dataset as a list of tuples of paired image paths: (rgb_path, gt_path)
    """
    dataset = []

    # Our dir names
    rgb_dir = 'RGB'
    gt_dir = 'GT'   

    # Get all the filenames from RGB folder
    rgb_fnames = sorted(os.listdir(os.path.join(root, rgb_dir)))

    # Compare file names from GT folder to file names from RGB:
    for gt_fname in sorted(os.listdir(os.path.join(root, gt_dir))):

            if gt_fname in rgb_fnames:
                # if we have a match - create pair of full path to the corresponding images
                rgb_path = os.path.join(root, rgb_dir, gt_fname)
                gt_path = os.path.join(root, gt_dir, gt_fname)

                item = (rgb_path, gt_path)
                # append to the list dataset
                dataset.append(item)
            else:
                continue

    return dataset

我们现在有什么?让我们将我们的函数与原始函数进行比较:

from torchvision.datasets.folder import make_dataset as make_dataset_original


dataset_original = make_dataset_original(root, {'RGB': 0, 'GT': 1}, extensions='png')
dataset = make_dataset(root)

print('Original make_dataset:')
print(*dataset_original, sep='\n')

print('Our make_dataset:')
print(*dataset, sep='\n')
Original make_dataset:
('./data/GT/img1.png', 1)
('./data/GT/img2.png', 1)
...
('./data/RGB/img1.png', 0)
('./data/RGB/img2.png', 0)
...
Our make_dataset:
('./data/RGB/img1.png', './data/GT/img1.png')
('./data/RGB/img2.png', './data/GT/img2.png')
...

我认为它很好用)是时候创建我们的类数据集了。这里最重要的部分是__getitem__方法,因为它导入图像、应用转换并返回一个张量,可供数据加载器使用。我们需要读取一对图像(rgb 和 gt)并返回一个包含 2 张张量图像的元组:

from torchvision.datasets.folder import default_loader
from torchvision.datasets.vision import VisionDataset


class CustomVisionDataset(VisionDataset):

    def __init__(self,
                 root,
                 loader=default_loader,
                 rgb_transform=None,
                 gt_transform=None):
        super().__init__(root,
                         transform=rgb_transform,
                         target_transform=gt_transform)

        # Prepare dataset
        samples = make_dataset(self.root)

        self.loader = loader
        self.samples = samples
        # list of RGB images
        self.rgb_samples = [s[1] for s in samples]
        # list of GT images
        self.gt_samples = [s[1] for s in samples]

    def __getitem__(self, index):
        """Returns a data sample from our dataset.
        """
        # getting our paths to images
        rgb_path, gt_path = self.samples[index]

        # import each image using loader (by default it's PIL)
        rgb_sample = self.loader(rgb_path)
        gt_sample = self.loader(gt_path)

        # here goes tranforms if needed
        # maybe we need different tranforms for each type of image
        if self.transform is not None:
            rgb_sample = self.transform(rgb_sample)
        if self.target_transform is not None:
            gt_sample = self.target_transform(gt_sample)      

        # now we return the right imported pair of images (tensors)
        return rgb_sample, gt_sample

    def __len__(self):
        return len(self.samples)

让我们测试一下:

from torch.utils.data import DataLoader

from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt


bs=4  # batch size
transforms = ToTensor()  # we need this to convert PIL images to Tensor
shuffle = True

dataset = CustomVisionDataset('./data', rgb_transform=transforms, gt_transform=transforms)
dataloader = DataLoader(dataset, batch_size=bs, shuffle=shuffle)

for i, (rgb, gt) in enumerate(dataloader):
    print(f'batch {i+1}:')
    # some plots
    for i in range(bs):
        plt.figure(figsize=(10, 5))
        plt.subplot(221)
        plt.imshow(rgb[i].squeeze().permute(1, 2, 0))
        plt.title(f'RGB img{i+1}')
        plt.subplot(222)
        plt.imshow(gt[i].squeeze().permute(1, 2, 0))
        plt.title(f'GT img{i+1}')
        plt.show()

出去:

batch 1:

一种 b C

...

在这里,您可以找到带有代码和简单虚拟数据集的笔记本。

于 2019-12-24T17:29:22.803 回答