1

我正在使用预训练的 VGG-16 模型进行图像分类。我正在添加自定义最后一层,因为我的分类类的数量是 10。我正在为模型训练 200 个时期。

我的问题是:如果我在某个时期随机停止(通过关闭 python 窗口)训练,有什么办法吗?比如说没有。50 和从那里恢复?我已经阅读了有关保存和重新加载模型的信息,但我的理解是这仅适用于我们的自定义模型,而不适用于 VGG-16 等预训练模型。

4

2 回答 2

4

您可以使用ModelCheckpoint回调定期保存模型。要使用它,请将callbacks参数传递给该fit方法:

from keras.callbacks import ModelCheckpoint
checkpointer = ModelCheckpoint(filepath='model-{epoch:02d}.hdf5', ...)
model.fit(..., callbacks=[checkpointer])

然后,稍后您可以加载最后保存的模型。有关此回调的更多自定义,请查看文档。

于 2018-08-24T17:02:29.820 回答
0

这是一个自定义版本的ModelCheckpoint,我用它来从给定的纪元gist恢复训练。它将epoch和其他日志保存到相应的JSON文件中,它还会在开始时检查是否恢复训练。您需要调用get_last_epoch并设置initial_epoch才能model.fit从那个时代恢复。

import json

class StatefulCheckpoint(ModelCheckpoint):
  """Save extra checkpoint data to resume training."""
  def __init__(self, weight_file, state_file=None, **kwargs):
    """Save the state (epoch etc.) along side weights."""
    super().__init__(weight_file, **kwargs)
    self.state_f = state_file
    self.state = dict()
    if self.state_f:
      # Load the last state if any
      try:
        with open(self.state_f, 'r') as f:
          self.state = json.load(f)
        self.best = self.state['best']
      except Exception as e: # pylint: disable=broad-except
        print("Skipping last state:", e)

  def on_train_begin(self, logs=None):
    prefix = "Resuming" if self.state else "Starting"
    print("{} training...".format(prefix))

  def on_epoch_end(self, epoch, logs=None):
    """Saves training state as well as weights."""
    super().on_epoch_end(epoch, logs)
    if self.state_f:
      state = {'epoch': epoch+1, 'best': self.best}
      state.update(logs)
      state.update(self.params)
      with open(self.state_f, 'w') as f:
        json.dump(state, f)

  def get_last_epoch(self, initial_epoch=0):
    """Return last saved epoch if any, or return default argument."""
    return self.state.get('epoch', initial_epoch)
于 2018-08-24T17:09:20.727 回答