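I am using a custom image generator template in Keras so that I can use an HDF5 file as input. Initially the code raised a "shape" error, so I followed this post and imported Sequence with from tensorflow.python.keras.utils.data_utils import Sequence. I am now using it in the form below, which you can also see in my Colab notebook: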
from numpy.random import uniform, randint
from tensorflow.python.keras.preprocessing.image import ImageDataGenerator
import numpy as np
from tensorflow.python.keras.utils.data_utils import Sequence

class CustomImagesGenerator(Sequence):
    def __init__(self, x, zoom_range, shear_range, rescale, horizontal_flip, batch_size):
        self.x = x
        self.zoom_range = zoom_range
        self.shear_range = shear_range
        self.rescale = rescale
        self.horizontal_flip = horizontal_flip
        self.batch_size = batch_size
        self.__img_gen = ImageDataGenerator()
        self.__batch_index = 0

    def __len__(self):
        # steps_per_epoch, if unspecified, will use the len(generator) as a number of steps.
        # hence this
        return np.floor(self.x.shape[0]/self.batch_size)

    # @property
    # def shape(self):
    #     return self.x.shape

    def next(self):
        return self.__next__()

    def __next__(self):
        start = self.__batch_index*self.batch_size
        stop = start + self.batch_size
        self.__batch_index += 1
        if stop > len(self.x):
            raise StopIteration
        transformed = np.array(self.x[start:stop])  # loads from hdf5
        for i in range(len(transformed)):
            zoom = uniform(self.zoom_range[0], self.zoom_range[1])
            transformations = {
                'zx': zoom,
                'zy': zoom,
                'shear': uniform(-self.shear_range, self.shear_range),
                'flip_horizontal': self.horizontal_flip and bool(randint(0, 2))
            }
            transformed[i] = self.__img_gen.apply_transform(transformed[i], transformations)
        import pdb;pdb.set_trace()
        return transformed * self.rescale
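For comparison, the Keras documentation states that every Sequence must implement __getitem__ and __len__; the enqueuer behind fit_generator requests batches by index rather than calling __next__. Below is a minimal sketch of that shape, assuming the HDF5 node can be sliced like a NumPy array. The class name ArrayBatchSequence is hypothetical, the augmentation step is left out, and it is a plain class so the sketch runs without TensorFlow (in real code it would subclass Sequence).

```python
import numpy as np


class ArrayBatchSequence:
    """Hypothetical sketch of the Keras Sequence protocol (__len__ + __getitem__).

    In real code this would subclass tensorflow.keras.utils.Sequence; it is a
    plain class here so the sketch runs without TensorFlow installed.
    """

    def __init__(self, x, batch_size, rescale=1.0):
        self.x = x  # any array-like that supports slicing, e.g. a PyTables node
        self.batch_size = batch_size
        self.rescale = rescale

    def __len__(self):
        # Number of full batches as a Python int; a trailing partial batch
        # is dropped, like the StopIteration check in the original __next__.
        return int(self.x.shape[0] // self.batch_size)

    def __getitem__(self, index):
        # Batches are requested by index, so no internal cursor is needed.
        if not 0 <= index < len(self):
            raise IndexError(index)
        start = index * self.batch_size
        stop = start + self.batch_size
        batch = np.asarray(self.x[start:stop], dtype=np.float32)  # reads only this slice
        return batch * self.rescale
```

Like the original, this returns only images. When training with fit_generator, Keras expects each batch to be an (inputs, targets) tuple, so the labels for the same slice would also have to be returned.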
I call the generator like this:
import h5py
import tables

in_hdf5_file = tables.open_file("gdrive/My Drive/Colab Notebooks/dataset.hdf5", mode='r')
images = in_hdf5_file.root.train_img

my_gen = CustomImagesGenerator(
    images,
    zoom_range=[0.8, 1],
    batch_size=32,
    shear_range=6,
    rescale=1./255,
    horizontal_flip=False
)

classifier.fit_generator(my_gen, steps_per_epoch=100, epochs=1, verbose=1)
Importing Sequence solved the "shape" error, but now I get this error:
Exception in thread Thread-5:
Traceback (most recent call last):
  File "/usr/lib/python3.6/threading.py", line 916, in _bootstrap_inner
    self.run()
  File "/usr/lib/python3.6/threading.py", line 864, in run
    self._target(*self._args, **self._kwargs)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/utils/data_utils.py", line 742, in _run
    sequence = list(range(len(self.sequence)))
TypeError: 'numpy.float64' object cannot be interpreted as an integer
How can I fix this? I suspect it may again be a conflict between Keras packages, and I don't know how to resolve it.
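The last frame of the traceback fails while evaluating len(self.sequence). np.floor returns an np.float64 rather than a Python int, and Python's len() rejects a non-integer result from __len__ with exactly this TypeError. Below is a minimal reproduction without Keras. The class names and the enqueuer_indices helper are hypothetical stand-ins for the generator and for the line in data_utils.py.

```python
import numpy as np


class FloatLenGenerator:
    """Hypothetical stand-in: __len__ returns np.floor(...), as in the question."""

    def __init__(self, n_samples, batch_size):
        self.n_samples = n_samples
        self.batch_size = batch_size

    def __len__(self):
        # np.floor returns np.float64, which Python's len() does not accept
        return np.floor(self.n_samples / self.batch_size)


class IntLenGenerator(FloatLenGenerator):
    """Same computation, converted to a Python int before returning."""

    def __len__(self):
        return int(np.floor(self.n_samples / self.batch_size))


def enqueuer_indices(sequence):
    # Mirrors the failing line in data_utils.py: list(range(len(self.sequence)))
    return list(range(len(sequence)))


try:
    enqueuer_indices(FloatLenGenerator(100, 32))
except TypeError as err:
    print("TypeError:", err)

print(enqueuer_indices(IntLenGenerator(100, 32)))  # [0, 1, 2]
```

In the question's class, the analogous change would be returning int(np.floor(self.x.shape[0] / self.batch_size)) from __len__. This sketch only addresses the TypeError, not whether the rest of the generator works with Keras.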