我很难使用 scikit learn 在 Keras 中实现网格搜索。基于本教程,我编写了以下代码:
from keras.wrappers.scikit_learn import KerasClassifier
from sklearn.model_selection import GridSearchCV
def create_model():
model = Sequential()
model.add(Dense(100, input_shape=(max_len, len(alphabet)), kernel_regularizer=regularizers.l2(0.001)))
model.add(Dropout(0.85))
model.add(LSTM(100, input_shape=(100,)))
model.add(Dropout(0.85))
model.add(Dense(num_output_classes, activation='softmax'))
adam = Adam(lr=0.001, beta_1=0.9, beta_2=0.999, decay=1e-6)
model.compile(loss='categorical_crossentropy',
optimizer=adam,
metrics=['accuracy'])
return model
seed = 7
np.random.seed(seed)
model = KerasClassifier(build_fn=create_model, epochs=10, verbose=0)
batch_size = [10,20]
param_grid = dict(batch_size=batch_size)
grid = GridSearchCV(estimator=model, param_grid=param_grid, n_jobs=-1)
grid_result = grid.fit(train_data_reduced, train_labels_reduced)
print("Best: %f using %s" % (grid_result.best_score_, grid_result.best_params_))
means = grid_result.cv_results_['mean_test_score']
stds = grid_result.cv_results_['std_test_score']
params = grid_result.cv_results_['params']
for mean, stdev, param in zip(means, stds, params):
print("%f (%f) with: %r" % (mean, stdev, param))
它没有给我任何错误消息,但它只是永远运行下去,而不会打印出任何东西。我特意用很少的 epoch、很少的训练示例和很少的超参数来搜索它。如果没有网格搜索,一个 epoch 的运行速度非常快,所以我认为我不需要给它更多的时间。它根本没有做任何事情。
谁能指出我错过了什么?
非常感谢!