我试图优化用于图像分类的 keras CNN 的超参数。我考虑使用 sklearn 和 talos 优化器 ( https://github.com/autonomio/talos ) 的网格搜索。我克服了从 flow_from_directory 制作 x 和 y 的基本困难(下面的代码),但是......它仍然不起作用!任何想法?也许有人面临同样的问题。
def talos_model(train_flow, validation_flow, nb_train_samples, nb_validation_samples, params):
model = Sequential()
model.add(Conv2D(6,(5,5),activation="relu",padding="same",
input_shape=(img_width, img_height, 3)))
model.add(MaxPooling2D((2,2)))
model.add(Dropout(params['dropout']))
model.add(Conv2D(16,(5,5),activation="relu"))
model.add(MaxPooling2D((2,2)))
model.add(Dropout(params['dropout']))
model.add(Flatten())
model.add(Dense(120, activation='relu', kernel_initializer=params['kernel_initializer']))
model.add(Dropout(params['dropout']))
model.add(Dense(84, activation='relu', kernel_initializer=params['kernel_initializer']))
model.add(Dropout(params['dropout']))
model.add(Dense(10, activation='softmax'))
model.compile(loss=params['loss'],
optimizer=params['optimizer'],
metrics=['categorical_accuracy'])
checkpointer = ModelCheckpoint(filepath='talos_cnn.h5py',
monitor='val_categorical_accuracy', save_best_only=True)
history = model.fit_generator(generator=train_flow,
samples_per_epoch=nb_train_samples,
validation_data=validation_flow,
nb_val_samples=nb_validation_samples,
callbacks=[checkpointer],
verbose=1,
epochs=params['epochs'])
return history, model
train_generator = ImageDataGenerator(rescale=1/255)
validation_generator = ImageDataGenerator(rescale=1/255)
# Retrieve images and their classes for train and validation sets
train_flow = train_generator.flow_from_directory(directory=train_data_dir,
batch_size=batch_size,
target_size(img_height,img_width))
validation_flow = validation_generator.flow_from_directory(directory=validation_data_dir,
batch_size=batch_size,
target_size=(img_height,img_width),
shuffle = False)
#here I make x and y for talos
(X_train, Y_train) = train_flow.next()
#starting an experiment with talos
t = ta.Scan(x=X_train,
y=Y_train,
model=talos_model,
params=params,
dataset_name='landmarks',
experiment_no='1')
最后一行出现错误:
具有多个元素的数组的真值是不明确的。使用 a.any() 或 a.all()