1

这是我的问题:如何在我的代码中使用提前停止?我应该把它放在哪个部分?

callbacks=[tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=10,mode="auto")]

我的代码:


numpy.random import seed
seed(1)

def create_model(optimizer='rmsprop'):
    model = Sequential()
    model.add(LSTM(50, activation='relu', return_sequences=True))
    model.add(LSTM(50, activation='relu'))
    model.add(Dense(1))

    model.compile(loss='mse',optimizer = optimizer)

    return model

clf = KerasRegressor(build_fn=create_model,epochs = 500,callbacks=[tf.keras.callbacks.EarlyStopping( patience=10)])

param_grid = {
'clf__optimizer' : ['adam','rmsprop'],
'clf__batch_size' : [500,45,77]
}

pipeline = Pipeline([
('clf',clf)
])

from sklearn.model_selection import TimeSeriesSplit, GridSearchCV

tscv = TimeSeriesSplit(n_splits=5)

grid = GridSearchCV(pipeline, cv=tscv,param_grid=param_grid,return_train_score=True,verbose=10,
scoring = 'neg_mean_squared_error')

grid.fit(Xtrain2,ytrain.values)

grid.cv_results_

我把回调放在'grid.fit'和'param_grid'中,但我出错了!!!

4

2 回答 2

0

您需要直接使用model.fit()函数训练 keras 模型,您会看到它允许您传入回调参数

于 2020-06-22T12:11:24.727 回答
0

回调在KerasRegressor.fit( docs ) 中指定,并GridSearchCV.fit接受fit_params关键字参数。来自文档

** fit_params : str -> object 的字典

传递给fit估算器方法的参数

所以类似的东西

grid.fit(Xtrain2, ytrain.values, callbacks=[...])

通常应该可以工作。在您的情况下,因为您已经嵌入到管道中,所以您需要另外限定模型的范围,如

grid.fit(Xtrain2, ytrain.values, clf__callbacks=[...])

另请参阅我可以向 KerasClassifier 发送回调吗?,尽管该问题还有很多其他问题。

于 2021-05-17T16:56:31.233 回答