我想使用其中一种图像增强技术(例如旋转或水平翻转)并将其应用于 CIFAR-10 数据集的一些图像并在 PyTorch 中绘制它们。
我知道我们可以使用以下代码来增强图像:
from torchvision import models, datasets, transforms
from torchvision.datasets import CIFAR10
data_transforms = transforms.Compose([
# add augmentations
transforms.RandomHorizontalFlip(p=0.5),
# The output of torchvision datasets are PILImage images of range [0, 1].
# We transform them to Tensors of normalized range [-1, 1]
transforms.ToTensor(),
transforms.Normalize(mean, std)
])
然后当我想加载 Cifar10 数据集时,我使用了上面的转换:
train_set = CIFAR10(
root='./data/',
train=True,
download=True,
transform=data_transforms['train'])
据我所知,当使用此代码时,所有 CIFAR10 数据集都会被转换。
问题
我的问题是如何对数据集中的某些图像使用数据转换或增强技术并绘制它们?例如 10 个图像及其增强图像。