0

我正在尝试训练用于图像修复的自动编码器,其中输入图像是损坏的图像,输出图像是基本事实。

使用的数据集组织为:

/Dataset
    /corrupted
         img1.jpg
         img2.jpg
            .
            .
    /groundTruth
         img1.jpg
         img2.jpg
            .
            .

使用的图像数量相对较大。如何使用 Keras 图像数据生成器将数据提供给模型?我检查了 flow_from_directory 方法,但找不到合适的 class_mode 来使用(“损坏”文件夹中的每个图像都映射到“groundTruth”文件夹中具有相同名称的图像)

4

1 回答 1

0

如果没有提供您需要的功能的预建图像数据生成器,您可以创建自己的自定义数据生成器

为此,您必须通过子类化创建新的数据生成器类tf.keras.utils.Sequence。您需要在新类中实现__getitem____len__方法。__len__必须返回数据集中的批次数,同时__getitem__必须将单个批次中的元素作为元组返回。

你可以在这里阅读官方文档。下面是一个代码示例:

from skimage.io import imread
from skimage.transform import resize
import numpy as np
import math

# Here, `x_set` is list of path to the images
# and `y_set` are the associated classes.

class CIFAR10Sequence(Sequence):

    def __init__(self, x_set, y_set, batch_size):
        self.x, self.y = x_set, y_set
        self.batch_size = batch_size

    def __len__(self):
        return math.ceil(len(self.x) / self.batch_size)

    def __getitem__(self, idx):
        batch_x = self.x[idx * self.batch_size:(idx + 1) *
        self.batch_size]
        batch_y = self.y[idx * self.batch_size:(idx + 1) *
        self.batch_size]

        return np.array([
            resize(imread(file_name), (200, 200))
               for file_name in batch_x]), np.array(batch_y)

希望答案有帮助!

于 2022-03-04T23:08:04.307 回答