我想知道如何将 Inception v3 的优化器从 SGD 改为 Adam。我希望保留预训练的权重,但不再使用默认的 SGD 优化器,而是改用 Adam。然而添加 Adam 优化器后,程序会抛出错误,提示在预训练的检查点文件中找不到 Adam 优化器对应的变量:
NotFoundError (see above for traceback): Key OptimizeLoss/InceptionV3/Mixed_6b/Branch_2/Conv2d_0d_7x1/BatchNorm/beta/Adam_1 not found in checkpoint
[[Node: save_1/RestoreV2_525 = RestoreV2[dtypes=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save_1/Const_0_0, save_1/RestoreV2_525/tensor_names, save_1/RestoreV2_525/shape_and_slices)]]
[[Node: save_1/Assign_758/_1522 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device_incarnation=1, tensor_name="edge_3805_save_1/Assign_758", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:GPU:0"]()]]
只训练网络的顶层、而不训练 Inception 部分时,效果很好。下面是实现 Inception v3 的代码(来自 im2txt):
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from tensorflow.contrib.slim.python.slim.nets.inception_v3 import inception_v3_base
# Short alias for TF-Slim; used below for its layers (conv2d, batch_norm,
# avg_pool2d, dropout, flatten) and for arg_scope.
slim = tf.contrib.slim
def inception_v3(images,
                 trainable=True,
                 is_training=True,
                 weight_decay=0.00004,
                 stddev=0.1,
                 dropout_keep_prob=0.8,
                 use_batch_norm=True,
                 batch_norm_params=None,
                 add_summaries=True,
                 scope="InceptionV3"):
    """Builds an Inception V3 subgraph for image embeddings.

    Args:
        images: A float32 Tensor of shape [batch, height, width, channels].
        trainable: Whether the inception submodel should be trainable or not.
        is_training: Boolean indicating training mode or not. Only honoured
            when `trainable` is also True.
        weight_decay: Coefficient for L2 weight regularization. Only applied
            when `trainable` is True.
        stddev: The standard deviation of the truncated normal weight
            initializer.
        dropout_keep_prob: Dropout keep probability.
        use_batch_norm: Whether to use batch normalization. When False, the
            convolutions are built without any normalizer and
            `batch_norm_params` is ignored.
        batch_norm_params: Parameters for batch normalization. See
            tf.contrib.layers.batch_norm for details. If empty/None while
            `use_batch_norm` is True, the defaults defined below are used.
        add_summaries: Whether to add activation summaries for every
            Inception end point.
        scope: Optional Variable scope.

    Returns:
        net: A rank-2 Tensor [batch, channels]. It is the final Inception
            feature map, average-pooled over its full spatial extent, passed
            through dropout and flattened.
    """
    # Only consider the inception model to be in training mode if it's
    # trainable. This flag drives both batch-norm statistics updates and
    # dropout below.
    is_inception_model_training = trainable and is_training

    if use_batch_norm:
        # Default parameters for batch normalization.
        if not batch_norm_params:
            batch_norm_params = {
                "is_training": is_inception_model_training,
                "trainable": trainable,
                # Decay for the moving averages.
                "decay": 0.9997,
                # Epsilon to prevent 0s in variance.
                "epsilon": 0.001,
                # Collection containing the moving mean and moving variance.
                "variables_collections": {
                    "beta": None,
                    "gamma": None,
                    "moving_mean": ["moving_vars"],
                    "moving_variance": ["moving_vars"],
                },
            }
        normalizer_fn = slim.batch_norm
    else:
        # Disable the normalizer entirely. Passing slim.batch_norm with
        # normalizer_params=None would still apply batch norm (with slim's
        # defaults), which made use_batch_norm=False a no-op.
        batch_norm_params = None
        normalizer_fn = None

    # Only regularize weights that will actually be updated.
    if trainable:
        weights_regularizer = tf.contrib.layers.l2_regularizer(weight_decay)
    else:
        weights_regularizer = None

    with tf.variable_scope(scope, "InceptionV3", [images]) as scope:
        with slim.arg_scope(
                [slim.conv2d, slim.fully_connected],
                weights_regularizer=weights_regularizer,
                trainable=trainable):
            with slim.arg_scope(
                    [slim.conv2d],
                    weights_initializer=tf.truncated_normal_initializer(
                        stddev=stddev),
                    activation_fn=tf.nn.relu,
                    normalizer_fn=normalizer_fn,
                    normalizer_params=batch_norm_params):
                net, end_points = inception_v3_base(images, scope=scope)
                with tf.variable_scope("logits"):
                    # The static spatial size (shape[1:3]) is used as the
                    # pooling kernel, so height/width must be known at
                    # graph-construction time.
                    shape = net.get_shape()
                    net = slim.avg_pool2d(
                        net, shape[1:3], padding="VALID", scope="pool")
                    net = slim.dropout(
                        net,
                        keep_prob=dropout_keep_prob,
                        is_training=is_inception_model_training,
                        scope="dropout")
                    net = slim.flatten(net, scope="flatten")

    # Add summaries.
    if add_summaries:
        for v in end_points.values():
            tf.contrib.layers.summaries.summarize_activation(v)

    return net
下面是用于训练模型的代码:
def _warn_if_checkpoint_lacks_variables(checkpoint_dir):
    """Log a warning if the latest checkpoint in checkpoint_dir misses graph variables.

    slim.learning.train resumes from the latest checkpoint in its train_dir
    whenever one exists, restoring every variable the Saver covers. Suppose
    that checkpoint was written by a different configuration, e.g. with the
    SGD optimizer or with the Inception weights frozen. Then variables that
    only exist in the current graph, such as Adam slot variables
    (".../Adam", ".../Adam_1"), are absent, and the restore fails with
    NotFoundError.

    This helper only logs; it does not change control flow. It must be called
    with the training graph as the default graph, after the optimizer has
    created its variables.

    Args:
        checkpoint_dir: Directory that slim.learning.train will resume from.
    """
    checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir)
    if not checkpoint_path:
        # Fresh directory: training will start from init_fn instead.
        return
    saved_names = tf.train.NewCheckpointReader(
        checkpoint_path).get_variable_to_shape_map()
    missing = [v.op.name for v in tf.global_variables()
               if v.op.name not in saved_names]
    if missing:
        tf.logging.warning(
            "Latest checkpoint %s in %s lacks %d variable(s) of the current "
            "graph, e.g. %s. Restoring it will fail with NotFoundError. This "
            "typically happens after changing the optimizer (e.g. SGD -> "
            "Adam) or enabling Inception fine-tuning; train into a fresh "
            "--train_dir instead.",
            checkpoint_path, checkpoint_dir, len(missing),
            ", ".join(missing[:5]))


def main(unused_argv):
    """Builds the training graph and runs slim.learning.train.

    Configuration is read from module-level constants and flags
    (INPUT_FILE_PATTERN, TRAIN_DIR, VOCAB_FILEPATH, FLAGS) and from the
    `configuration` module.

    Raises:
        ValueError: If INPUT_FILE_PATTERN or TRAIN_DIR is empty.
    """
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if not INPUT_FILE_PATTERN:
        raise ValueError("--input_file_pattern is required")
    if not TRAIN_DIR:
        raise ValueError("--train_dir is required")

    model_config = configuration.ModelConfig()
    model_config.input_file_pattern = INPUT_FILE_PATTERN
    model_config.inception_checkpoint_file = FLAGS.inception_checkpoint_file
    model_config.train_inception = FLAGS.fine_tune
    training_config = configuration.TrainingConfig()
    model_config.pos_weight = training_config.pos_weight
    model_config.negatives = training_config.negatives
    model_config.class_weights = calculate_class_weights(
        VOCAB_FILEPATH,
        model_config.vocab_size,
        use_class_weights=FLAGS.use_class_weights,
        min_class_weight=training_config.min_class_weight
    )

    # Create training directory.
    train_dir = TRAIN_DIR
    if not tf.gfile.IsDirectory(train_dir):
        tf.logging.info("Creating training directory: %s", train_dir)
        tf.gfile.MakeDirs(train_dir)

    # Build the TensorFlow graph.
    g = tf.Graph()
    with g.as_default():
        # Build the model.
        model = filter_inception_model.FilterInceptionModel(
            model_config, mode="train")
        model.build()

        # Set up the learning rate.
        learning_rate_decay_fn = None
        learning_rate = tf.constant(training_config.learning_rate)
        if training_config.learning_rate_decay_factor > 0:
            num_batches_per_epoch = (training_config.num_examples_per_epoch /
                                     model_config.batch_size)
            decay_steps = int(num_batches_per_epoch *
                              training_config.num_epochs_per_decay)

            def _learning_rate_decay_fn(learning_rate, global_step):
                return tf.train.exponential_decay(
                    learning_rate,
                    global_step,
                    decay_steps=decay_steps,
                    decay_rate=training_config.learning_rate_decay_factor,
                    staircase=True)

            learning_rate_decay_fn = _learning_rate_decay_fn

        # Set up the training ops. training_config.optimizer is passed
        # straight to optimize_loss, which accepts an optimizer name such as
        # "SGD" or "Adam", or an Optimizer class/instance.
        train_op = tf.contrib.layers.optimize_loss(
            loss=model.total_loss,
            global_step=model.global_step,
            learning_rate=learning_rate,
            optimizer=training_config.optimizer,
            clip_gradients=training_config.clip_gradients,
            learning_rate_decay_fn=learning_rate_decay_fn,
            summaries=["gradients"]
        )

        # Set up the Saver for saving and restoring model checkpoints. It is
        # created after optimize_loss, so by default it also covers the
        # optimizer's slot variables.
        saver = tf.train.Saver(
            max_to_keep=training_config.max_checkpoints_to_keep)

        # Surface a stale train_dir checkpoint (e.g. written before switching
        # SGD -> Adam) before the restore fails deep inside training.
        _warn_if_checkpoint_lacks_variables(train_dir)

    # Run training.
    steps = round((FLAGS.number_of_epochs *
                   training_config.num_examples_per_epoch) /
                  model_config.batch_size)
    steps = int(steps)
    print("Total number of steps to process: %d" % steps)
    if steps != 0:
        # model.init_fn is only invoked when train_dir holds no checkpoint;
        # otherwise slim.learning.train restores from train_dir with `saver`.
        tf.contrib.slim.learning.train(
            train_op,
            train_dir,
            log_every_n_steps=FLAGS.log_every_n_steps,
            graph=g,
            global_step=model.global_step,
            save_summaries_secs=1800,
            number_of_steps=steps,
            init_fn=model.init_fn,
            saver=saver,
        )
如能就此提供任何见解,我将不胜感激。
编辑:添加错误消息