1

我想使用其中一种图像增强技术(例如旋转或水平翻转)并将其应用于 CIFAR-10 数据集的一些图像并在 PyTorch 中绘制它们。

我知道我们可以使用以下代码来增强图像:

from torchvision import models, datasets, transforms
from torchvision.datasets import CIFAR10

data_transforms = transforms.Compose([
        # add augmentations
        transforms.RandomHorizontalFlip(p=0.5),
        # The output of torchvision datasets are PILImage images of range [0, 1].
        # We transform them to Tensors of normalized range [-1, 1]
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])

然后当我想加载 Cifar10 数据集时,我使用了上面的转换:

train_set = CIFAR10(
    root='./data/',
    train=True,
    download=True,
    transform=data_transforms['train'])

据我所知,当使用此代码时,所有 CIFAR10 数据集都会被转换。

问题

我的问题是如何对数据集中的某些图像使用数据转换或增强技术并绘制它们?例如 10 个图像及其增强图像。

4

1 回答 1

1

使用此代码时,将转换所有 CIFAR10 数据集

__getitem__实际上,只有当用户通过函数或数据加载器获取数据集中的图像时,才会调用转换管道。所以在这个时间点,train_set不包含增强图像,它们是动态转换的。


您将需要构建另一个没有扩充的数据集。

>>> non_augmented = CIFAR10(
...     root='./data/',
...     train=True,
...     download=True)

>>> train_set = CIFAR10(
...     root='./data/',
...     train=True,
...     download=True,
...     transform=data_transforms)

将一些图像堆叠在一起:

>>> imgs = torch.stack((*[non_augmented[i][0] for i in range(10)],
                        *[train_set[i][0] for i in range(10)]))

>>> imgs.shape
torch.Size([20, 3, 32, 32])

然后torchvision.utils.make_grid可用于创建所需的布局:

>>> grid = torchvision.utils.make_grid(imgs, nrow=10)

你有它!

>>> transforms.ToPILImage()(grid)

在此处输入图像描述

于 2021-09-09T06:55:06.530 回答