我想对我的 Keras 模型执行超参数优化。问题是数据集很大,通常在训练中我使用fit_generator
从磁盘批量加载数据,但是像SKlearn Gridsearch,Talos等常见的包只支持fit
方法。
我尝试使用以下方法将整个数据加载到内存中:
train_generator = train_datagen.flow_from_directory(
original_dir,
target_size=(img_height, img_width),
batch_size=train_nb,
class_mode='categorical')
X_train,y_train = train_generator.next()
但是在执行网格搜索时,操作系统会因为内存使用量大而将其杀死。我还尝试将我的数据集欠采样到仅 25%,但它仍然太大。
有人和我有同样的经历吗?您能否分享您对大型数据集执行超参数优化的策略?
根据@dennis-ec 的回答,我尝试在此处遵循 SkOpt 的教程:http ://slashtutorial.com/ai/tensorflow/19_hyper-parameters/这是一个非常全面的教程