1

我正在尝试使用Momentum Optimizer微调 vgg_16模型。为此,我使用了这里的预训练模型。

在微调之前,我从模型中分配变量值如下,

variables_to_restore = slim.get_variables_to_restore(exclude=["vgg_16/fc8"])
init_assign_op, init_feed_dict = slim.assign_from_checkpoint(model_path, variables_to_restore)

请注意,我不排除vgg_16/*/*/Momentum变量。因此我收到一个错误,

ValueError: Checkpoint is missing variable [vgg_16/conv1/conv1_1/weights/Momentum],

正如预期的那样。

我的问题是在排除列表中包含所有 Momentum 变量非常麻烦(示例)。有没有更聪明的方法来排除动量变量?

这很重要,因为对于 resnet 等大型模型,手动输入排除项是不可能的。

先感谢您!

4

1 回答 1

1

您可以使用以下代码解决此问题:

def _init_fn():


    variables_to_restore = []
      for var in slim.get_model_variables():
        excluded = False
        for exclusion in exclusions:
          if var.op.name.startswith(exclusion):
            excluded = True
            break
        if not excluded:
          variables_to_restore.append(var)

      if tf.gfile.IsDirectory(FLAGS.checkpoint_path):
        checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_path)
      else:
        checkpoint_path = FLAGS.checkpoint_path

      tf.logging.info('Fine-tuning from %s' % checkpoint_path)

      return slim.assign_from_checkpoint_fn(
          checkpoint_path,
          variables_to_restore,
    ignore_missing_vars=FLAGS.ignore_missing_vars)

slim.learning.train(init_fn=init_fn,)

于 2018-06-12T07:37:45.777 回答