1

我按照以下页面的说明，用 12 张风格图像训练了一个模型：https://github.com/tensorflow/magenta/tree/master/magenta/models/image_stylization

但是，输出的检查点文件非常大：model.ckpt-0.data-00000-of-00001（574MB）。

他们提供的预训练模型（例如 Monet.ckpt 和 Varied.ckpt）分别只有 6.8MB 和 7.1MB。如何才能把我的模型从 574MB 显著缩小到与预训练模型相当的大小？

代码如下：

    """Trains the N-styles style transfer model."""

    from __future__ import absolute_import
    from __future__ import division
    from __future__ import print_function

    import ast
    import os

    # internal imports
    import tensorflow as tf
    from tensorflow.core.protobuf import saver_pb2

    from magenta.models.image_stylization import image_utils
    from magenta.models.image_stylization import learning
    from magenta.models.image_stylization import model
    from magenta.models.image_stylization import vgg

    slim = tf.contrib.slim

    # Default loss weights, written as dict literals inside strings and keyed by
    # VGG-16 layer name; main() turns them into dicts with ast.literal_eval.
    DEFAULT_CONTENT_WEIGHTS = '{"vgg_16/conv3": 1.0}'
    DEFAULT_STYLE_WEIGHTS = ('{"vgg_16/conv1": 1e-4, '
                             '"vgg_16/conv2": 1e-4, '
                             '"vgg_16/conv3": 1e-4, '
                             '"vgg_16/conv4": 1e-4}')

    flags = tf.app.flags

    # --- Floating-point hyper-parameters -------------------------------------
    flags.DEFINE_float('clip_gradient_norm', 0,
                       'Clip gradients to this norm')
    flags.DEFINE_float('learning_rate', 1e-3,
                       'Learning rate')

    # --- Integer settings ------------------------------------------------------
    flags.DEFINE_integer('batch_size', 16,
                         'Batch size.')
    flags.DEFINE_integer('image_size', 256,
                         'Image size.')
    flags.DEFINE_integer(
        'ps_tasks', 0,
        'Number of parameter servers. If 0, parameters are handled locally '
        'by the worker.')
    flags.DEFINE_integer('num_styles', 12,
                         'Number of styles.')
    flags.DEFINE_integer(
        'save_summaries_secs', 600,
        'Frequency at which summaries are saved, in seconds.')
    flags.DEFINE_integer(
        'save_interval_secs', 600,
        'Frequency at which the model is saved, in seconds.')
    flags.DEFINE_integer(
        'task', 0,
        'Task ID. Used when training with multiple workers to identify each '
        'worker.')
    # NOTE(review): a default of only 2 training steps looks like a debugging
    # value -- confirm before using this default for real training.
    flags.DEFINE_integer('train_steps', 2,
                         'Number of training steps.')

    # --- String settings -------------------------------------------------------
    flags.DEFINE_string('content_weights', DEFAULT_CONTENT_WEIGHTS,
                        'Content weights')
    flags.DEFINE_string('master', '',
                        'Name of the TensorFlow master to use.')
    flags.DEFINE_string(
        'style_coefficients', None,
        'Scales the style weights conditioned on the style image.')
    # NOTE(review): the two path defaults below are absolute paths on one
    # specific machine, inside an installed site-packages tree; on any other
    # machine --style_dataset_file and --train_dir must be passed explicitly.
    flags.DEFINE_string(
        'style_dataset_file',
        '/Users/guanjhensu/miniconda2/envs/magenta/lib/python2.7/site-packages/magenta/models/image_stylization/style_images_01.tfrecord',
        'Style dataset file.')
    flags.DEFINE_string('style_weights', DEFAULT_STYLE_WEIGHTS,
                        'Style weights')
    flags.DEFINE_string(
        'train_dir',
        '/Users/guanjhensu/miniconda2/envs/magenta/lib/python2.7/site-packages/magenta/models/image_stylization/train_checkpoint',
        'Directory for checkpoints and summaries.')

    FLAGS = flags.FLAGS


    def main(unused_argv=None):
      """Builds the N-styles style-transfer training graph and trains it.

      Both input pipelines (content batches from image_utils.imagenet_inputs and
      style images from FLAGS.style_dataset_file) are placed under a CPU worker
      device. The style-transfer network, the losses and the train op are built
      under a second replica_device_setter scope. Training then runs through
      slim.learning.train, which writes checkpoints and summaries to
      FLAGS.train_dir.

      Why the checkpoints are large: no `saver` is passed to
      slim.learning.train, so it creates a default tf.train.Saver over all
      global variables in the graph. Besides the style-transfer network's own
      variables, that includes:
        * the `vgg_16` variables, which init_fn restores from
          vgg.checkpoint_file();
        * the slot variables created by tf.train.AdamOptimizer.
      NOTE(review): this is presumably why these checkpoints are far larger
      than the released pre-trained style checkpoints, which appear to hold
      only the style-transfer network. To get a small file for inference,
      restore a trained checkpoint and re-save only the network's variables
      with tf.train.Saver(var_list=...). Inspect the checkpoint's variable
      names first to confirm the right scope.

      Args:
        unused_argv: Leftover command-line arguments from tf.app.run; ignored.

      Raises:
        ValueError: If the number of style coefficients differs from
          FLAGS.num_styles.
      """
      with tf.Graph().as_default():
        # Force all input processing onto CPU in order to reserve the GPU for
        # the forward inference and back-propagation.
        device = '/cpu:0' if not FLAGS.ps_tasks else '/job:worker/cpu:0'
        with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks,
                                                      worker_device=device)):
          inputs, _ = image_utils.imagenet_inputs(FLAGS.batch_size,
                                                  FLAGS.image_size)
          # Load style images and select one at random (for each graph
          # execution, a new random selection occurs).
          _, style_labels, style_gram_matrices = image_utils.style_image_inputs(
              os.path.expanduser(FLAGS.style_dataset_file),
              batch_size=FLAGS.batch_size, image_size=FLAGS.image_size,
              square_crop=True, shuffle=True)

        # Sibling of the input scope above (still inside the same graph): the
        # model, losses and optimizer are placed by the default device setter.
        with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
          # Process style and weight flags.
          num_styles = FLAGS.num_styles
          if FLAGS.style_coefficients is None:
            style_coefficients = [1.0] * num_styles
          else:
            style_coefficients = ast.literal_eval(FLAGS.style_coefficients)
          if len(style_coefficients) != num_styles:
            raise ValueError(
                'number of style coefficients differs from number of styles')
          content_weights = ast.literal_eval(FLAGS.content_weights)
          style_weights = ast.literal_eval(FLAGS.style_weights)

          # Rescale style weights dynamically based on the current style image:
          # style_labels selects the coefficient(s) via tf.gather.
          style_coefficient = tf.gather(
              tf.constant(style_coefficients), style_labels)
          # .items() works on both Python 2 and 3 (.iteritems() is Python 2 only).
          style_weights = {key: style_coefficient * value
                           for key, value in style_weights.items()}

          # Define the model; its normalization is conditioned on style_labels,
          # with num_categories equal to the number of styles.
          stylized_inputs = model.transform(
              inputs,
              normalizer_params={
                  'labels': style_labels,
                  'num_categories': num_styles,
                  'center': True,
                  'scale': True})

          # Compute losses and export each component as a scalar summary.
          total_loss, loss_dict = learning.total_loss(
              inputs, stylized_inputs, style_gram_matrices, content_weights,
              style_weights)
          for key, value in loss_dict.items():
            tf.summary.scalar(key, value)

          # Set up training.
          optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate)
          train_op = slim.learning.create_train_op(
              total_loss, optimizer, clip_gradient_norm=FLAGS.clip_gradient_norm,
              summarize_gradients=False)

          # Saver restricted to the VGG-16 variables; init_fn uses it to load
          # their values from vgg.checkpoint_file().
          # TODO(iansimon): This is ugly, but assign_from_checkpoint_fn doesn't
          # exist yet.
          vgg_saver = tf.train.Saver(slim.get_variables('vgg_16'))

          def init_fn(session):
            vgg_saver.restore(session, vgg.checkpoint_file())

          # Run training. No `saver` is passed here, so slim.learning.train uses
          # a default Saver over all global variables (see the docstring for
          # the effect on checkpoint size).
          slim.learning.train(
              train_op=train_op,
              logdir=os.path.expanduser(FLAGS.train_dir),
              master=FLAGS.master,
              is_chief=FLAGS.task == 0,
              number_of_steps=FLAGS.train_steps,
              init_fn=init_fn,
              save_summaries_secs=FLAGS.save_summaries_secs,
              save_interval_secs=FLAGS.save_interval_secs)


    def console_entry_point():
      """Console-script entry point: hands `main` to tf.app.run."""
      tf.app.run(main=main)


    if __name__ == '__main__':
      # Executing the file directly goes through the same entry point as the
      # installed console script.
      console_entry_point()
4

0 回答 0