1

我很难使用 scikit learn 在 Keras 中实现网格搜索。基于本教程,我编写了以下代码:

from keras.wrappers.scikit_learn import KerasClassifier
from sklearn.model_selection import GridSearchCV

    def create_model():
            model = Sequential()
            model.add(Dense(100, input_shape=(max_len, len(alphabet)), kernel_regularizer=regularizers.l2(0.001)))
            model.add(Dropout(0.85))
            model.add(LSTM(100, input_shape=(100,))) 
            model.add(Dropout(0.85))
            model.add(Dense(num_output_classes, activation='softmax'))

            adam = Adam(lr=0.001, beta_1=0.9, beta_2=0.999, decay=1e-6)

            model.compile(loss='categorical_crossentropy',
                      optimizer=adam,
                      metrics=['accuracy']) 

            return model

    seed = 7
    np.random.seed(seed)

    model = KerasClassifier(build_fn=create_model, epochs=10, verbose=0)

    batch_size = [10,20]
    param_grid = dict(batch_size=batch_size)
    grid = GridSearchCV(estimator=model, param_grid=param_grid, n_jobs=-1)
    grid_result = grid.fit(train_data_reduced, train_labels_reduced)

    print("Best: %f using %s" % (grid_result.best_score_, grid_result.best_params_))
    means = grid_result.cv_results_['mean_test_score']
    stds = grid_result.cv_results_['std_test_score']
    params = grid_result.cv_results_['params']
    for mean, stdev, param in zip(means, stds, params):
        print("%f (%f) with: %r" % (mean, stdev, param))

它没有给我任何错误消息,但它只是永远运行下去,而不会打印出任何东西。我特意用很少的 epoch、很少的训练示例和很少的超参数来搜索它。如果没有网格搜索,一个 epoch 的运行速度非常快,所以我认为我不需要给它更多的时间。它根本没有做任何事情。

谁能指出我错过了什么?

非常感谢!

4

1 回答 1

1

我有同样的问题。

n_jobs=-1从您的参数列表中删除可能会有所帮助!也尽量不要做热编码。

于 2018-06-08T16:53:42.110 回答