4

我在下面有代码,

train_datagen = ImageDataGenerator(
        rescale=1./255,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True)
test_datagen = ImageDataGenerator(rescale=1./255)
train_generator = train_datagen.flow_from_directory(
        'data/train',
        target_size=(150, 150),
        batch_size=32,
        class_mode='binary')
validation_generator = test_datagen.flow_from_directory(
        'data/validation',
        target_size=(150, 150),
        batch_size=32,
        class_mode='binary')

现在model.fit_generator定义如下:

model.fit_generator(
        train_generator,
        steps_per_epoch=2000,
        epochs=50,
        validation_data=validation_generator,
        validation_steps=800)

现在已弃用,在这种情况下model.fit_generator更改为的正确方法是什么model.fit_generatormodel.fit

4

2 回答 2

4

您只需更改model.fit_generator()model.fit().

从 TensorFlow 2.1 开始,model.fit()也接受生成器作为输入。就如此容易。

来自 TensorFlow 的官方文档:

警告:此功能已弃用。它将在未来的版本中删除。更新说明:请使用支持生成器的 Model.fit。

于 2020-03-08T09:08:22.457 回答
3

旧训练 = model.fit_generator(generator=train_generator, steps_per_epoch=2048 //36,epochs=10,validation_data=validation_generator,validation_steps=832//16)

摆脱 'generator=' 新训练 = model.fit(train_generator, steps_per_epoch=2048 // 128,epochs=10,validation_data=validation_generator,validation_steps=832//16)

于 2020-04-03T07:32:33.507 回答