如果没有提供您需要的功能的预建图像数据生成器,您可以创建自己的自定义数据生成器。
为此,您必须通过子类化创建新的数据生成器类tf.keras.utils.Sequence
。您需要在新类中实现__getitem__
和__len__
方法。__len__
必须返回数据集中的批次数,同时__getitem__
必须将单个批次中的元素作为元组返回。
你可以在这里阅读官方文档。下面是一个代码示例:
from skimage.io import imread
from skimage.transform import resize
import numpy as np
import math
# Here, `x_set` is list of path to the images
# and `y_set` are the associated classes.
class CIFAR10Sequence(Sequence):
def __init__(self, x_set, y_set, batch_size):
self.x, self.y = x_set, y_set
self.batch_size = batch_size
def __len__(self):
return math.ceil(len(self.x) / self.batch_size)
def __getitem__(self, idx):
batch_x = self.x[idx * self.batch_size:(idx + 1) *
self.batch_size]
batch_y = self.y[idx * self.batch_size:(idx + 1) *
self.batch_size]
return np.array([
resize(imread(file_name), (200, 200))
for file_name in batch_x]), np.array(batch_y)
希望答案有帮助!