我正在编写代码，使用 Keras 训练一个 CNN，并用 Hyperas 进行超参数搜索。为了训练模型，我使用了 ImageDataGenerator 的 flow_from_directory 函数。
我阅读了很多在网上找到的帖子和文档，但代码仍然无法运行，我不明白原因。
下面是我的代码:
```python
# Installation d'hyperas
!pip install hyperas
# Accès aux fichiers de gDrive
from google.colab import drive
drive.mount('/content/gdrive')
# Copie de la class data_gen.py sur la racine de gColab
!cp '/content/gdrive/My Drive/Deep-
learning/Projets/CNN_dogs&cats/CNN_cats&dogs_2600_hyperopt_gColab.ipynb' '/content/'
# Importation des librairies
from keras.preprocessing.image import ImageDataGenerator
from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D
from keras.layers import Activation, Dropout, Flatten, Dense
from keras import backend as K
import matplotlib.pyplot as plt
import numpy as np
from keras.optimizers import SGD
from hyperopt import Trials, STATUS_OK, tpe
from hyperas import optim
from hyperas.distributions import choice, uniform
# Notebook-level settings, read by the __main__ block below.
# NOTE(review): hyperas rebuilds data()/model() from their source text, so these
# globals are presumably NOT visible to them inside hyperas' run -- confirm.

# Input image size (pixels) fed to the network.
img_width = 150
img_height = 150

# Training schedule.
batch_size = 16
epochs = 20

# Image counts per split; used to derive steps per epoch.
nb_train_samples = 1855
nb_validation_samples = 745
nb_test_samples = 2750

test_data_dir = '/content/gdrive/My Drive/Deep-learning/Projets/CNN_dogs&cats/data/PetImages/test'
def data():
    """Create the training and validation image iterators.

    hyperas does not simply call this function. It reads its source with
    inspect.getsource, strips the def line, looks for the return line and
    measures the body's indentation (see the traceback in the question).
    The body must therefore be indented; otherwise determine_indent returns
    None and hyperas fails with "object of type 'NoneType' has no len()".
    The extracted body presumably runs outside this notebook's namespace,
    so every constant it needs is defined here rather than read from
    notebook globals.
    """
    img_width, img_height = 150, 150
    batch_size = 16
    # Each path must stay on one line; a line break inside the quotes is a syntax error.
    train_data_dir = '/content/gdrive/My Drive/Deep-learning/Projets/CNN_dogs&cats/data/PetImages_2600/train'
    validation_data_dir = '/content/gdrive/My Drive/Deep-learning/Projets/CNN_dogs&cats/data/PetImages_2600/validation'

    # Scale pixel values by 1/255 (0-255 -> 0-1 for 8-bit images); no augmentation.
    train_datagen = ImageDataGenerator(rescale=1. / 255)
    validation_datagen = ImageDataGenerator(rescale=1. / 255)

    # class_mode='binary' yields one 0/1 label per image, matching the
    # single-unit output and binary_crossentropy loss of the model.
    train_generator = train_datagen.flow_from_directory(train_data_dir, target_size=(img_width, img_height), batch_size=batch_size, class_mode='binary')
    validation_generator = validation_datagen.flow_from_directory(validation_data_dir, target_size=(img_width, img_height), batch_size=batch_size, class_mode='binary')
    return train_generator, validation_generator
def model(train_generator, validation_generator):
    """Build, train and score one CNN for a single hyperas trial.

    hyperas rewrites this function from its source text and substitutes each
    double-brace template with a value sampled from the search space. The
    function therefore has to be self-contained: the constants it needs are
    defined locally instead of being read from notebook globals.

    The result is the dict hyperopt expects from an objective. 'loss' is the
    negated validation accuracy, so minimising it maximises accuracy, and
    'model' hands the trained network back to optim.minimize.
    """
    img_width, img_height = 150, 150
    batch_size = 16
    epochs = 20
    nb_train_samples = 1855
    nb_validation_samples = 745

    # The position of the colour-channel axis depends on the Keras backend config.
    if K.image_data_format() == 'channels_first':
        input_shape = (3, img_width, img_height)
    else:
        input_shape = (img_width, img_height, 3)

    # Three conv / ReLU / max-pool stages followed by a small dense head.
    model = Sequential()
    model.add(Conv2D(32, (3, 3), input_shape=input_shape))
    model.add(Activation('relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Conv2D(32, (3, 3)))
    model.add(Activation('relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Conv2D(64, (3, 3)))
    model.add(Activation('relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Flatten())
    model.add(Dense(64))
    model.add(Activation('relu'))
    model.add(Dropout({{choice([0.1, 0.2, 0.3, 0.4, 0.5])}}))
    model.add(Dense(1))
    # binary_crossentropy needs a probability in [0, 1]. Sigmoid guarantees
    # that; relu is unbounded, so it is not a valid choice for this layer.
    model.add(Activation('sigmoid'))
    # hyperopt's uniform takes the lower and upper bounds as two separate arguments.
    model.compile(loss='binary_crossentropy', optimizer=SGD(lr={{uniform(0, 1)}}), metrics=['accuracy'])

    # Train on the training iterator and report validation metrics every epoch.
    model.fit_generator(train_generator, steps_per_epoch=nb_train_samples // batch_size, epochs=epochs, validation_data=validation_generator, validation_steps=nb_validation_samples // batch_size)
    # evaluate_generator yields [loss, accuracy]; only the accuracy drives the search.
    _, acc = model.evaluate_generator(generator=validation_generator, steps=nb_validation_samples // batch_size)
    return {'loss': -acc, 'status': STATUS_OK, 'model': model}
if __name__ == '__main__':
    # data() is called directly here only to obtain a validation iterator for
    # the final report; hyperas builds its own copy for the search itself.
    train_generator, validation_generator = data()
    best_run, best_model = optim.minimize(model=model, data=data, algo=tpe.suggest, max_evals=10, notebook_name='CNN_cats&dogs_2600_hyperopt_gColab', trials=Trials())
    print('Evaluation of best performing model:')
    # The validation data is a generator, so use the generator API with an
    # explicit step count, as the training code does.
    print(best_model.evaluate_generator(validation_generator, steps=nb_validation_samples // batch_size))
```
出错的是这一行：
best_run, best_model = optim.minimize(model=model, data=data, algo=tpe.suggest, max_evals=10, notebook_name='CNN_cats&dogs_2600_hyperopt_gColab', trials=Trials())
报错信息如下：
/usr/local/lib/python3.6/dist-packages/hyperas/optim.py in retrieve_data_string(data, verbose)
219 data_string = inspect.getsource(data)
220 first_line = data_string.split("\n")[0]
---> 221 indent_length = len(determine_indent(data_string))
222 data_string = data_string.replace(first_line, "")
223 r = re.compile(r'^\s*return.*'
TypeError: object of type 'NoneType' has no len()