1

我想展示一些增强训练图像的样本。

我的转换包括这样的标准 ImageNet transforms.Normalize

train_transforms = transforms.Compose([transforms.RandomRotation(30),
                                       transforms.RandomResizedCrop(224),
                                       transforms.RandomHorizontalFlip(),
                                       transforms.ToTensor(),
                                       transforms.Normalize([0.485, 0.456, 0.406],
                                                            [0.229, 0.224, 0.225])])

但是,由于Normalise,图像以奇怪的颜色显示。

这个答案说我需要访问原始图像,这在加载时应用转换时很困难:

image_datasets['train'] = datasets.ImageFolder(train_dir, transform=train_transforms)

在使用标准化图像进行计算时,我将如何以通常的颜色显示一些示例增强图像?

4

4 回答 4

3

我建议两种选择:

  1. 创建一个单独的“转换”阶段,显示图像并将其进一步传递而无需更改。免费的好处是您可以在转换列表的任何阶段插入。
    import cv2
    import numpy as np
    def TransformShow(name="img", wait=100):
        def transform_show(img):
            cv2.imshow(name, np.array(img))
            cv2.waitKey(wait)
            return img
        return transform_show

在之前插入这个“变压器” ToTensor()

                                       transforms.RandomHorizontalFlip(),
                                       TransformShow("window_name", delay_in_ms),
                                       transforms.ToTensor(),

使用零delay_in_ms等待按键。

我在这里使用OpenCV来显示图像。也可以只用 Pillow/PIL 来完成,但我不喜欢它的处理方式。

  1. 撤消标准化并显示图像。
def show_image(img, name="img", wait=100):
    mean = np.array([0.485, 0.456, 0.406])
    std =  np.array([0.229, 0.224, 0.225])
    cv2.imshow(name, img.cpu().numpy().transpose((1,2,0)) * std + mean)
    cv2.waitKey(wait)

然后将其称为

        show_image(data[0], "unaug", 1)
  1. 最后一种方法可以用快速的两线近似,但颜色有些失真:
    cv2.imshow("approx", data[0].cpu().numpy().transpose((1,2,0)) * 0.225 + 0.45)
    cv2.waitKey(10)
于 2019-03-15T18:47:55.130 回答
1

我已经面临同样的问题。我的解决方案是创建一个不同torch.Dataset的数据增强但没有规范化。

在这里,我创建了Dataset. 在这里,我有一个实现增强的类。我有两个成员:self.tf_augmentself.tf_transform。前者仅应用数据增强,而后者应用数据增强和规范化。

于 2019-03-15T11:34:32.617 回答
0

只需撤消归一化操作,即乘以标准偏差并添加平均值。

请看pytorch教程中的imshow方法: https ://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html#visualize-a-few-images

于 2019-03-16T06:04:17.270 回答
0

为了回答我自己的问题,我想出了以下内容:

示例输出

# Undo transforms.Normalize
def denormalise(image):
    image = image.numpy().transpose(1, 2, 0)  # PIL images have channel last
    mean = [0.485, 0.456, 0.406]
    stdd = [0.229, 0.224, 0.225]
    image = (image * stdd + mean).clip(0, 1)
    return image


example_rows = 2
example_cols = 5

sampler = torch.utils.data.RandomSampler(image_datasets['train'],
                                         num_samples=example_rows * example_cols)

# Get a batch of images and labels  
images, indices = next(iter(sampler)) 

plt.rcParams['figure.dpi'] = 120  # Increase size of pyplot plots

# Show a grid of example images    
fig, axes = plt.subplots(example_rows, example_cols, figsize=(9, 5)) #  sharex=True, sharey=True)
axes = axes.flatten()
for ax, image, index in zip(axes, images, indices):
    ax.imshow(denormalise(image))
    ax.set_axis_off()
    ax.set_title(class_names[index], fontsize=7)

fig.subplots_adjust(wspace=0.02, hspace=0)
fig.suptitle('Augmented training set images', fontsize=20)
plt.show()

这是基于PyTorch 的迁移学习教程的代码,但在每张图像上方显示标题,通常看起来更好。

于 2019-03-18T11:37:32.947 回答