最近,我尝试使用多个 GPU 来加速训练。但是我对批量标准化有一些问题。具体来说,当我使用ExponentialMovingAverage来获得平均批次平均值和变化时,准确性很差。
我尝试了几种方法(函数)来实现批量标准化,如下所示。其余部分代码都是一样的,我只是尝试了不同的批量归一化函数。使用 2 个 GPU 时,方法 2-4 运行良好,但方法 1 在测试数据集上的准确度确实很差。当我切换到只使用 1 个 GPU 时,所有方法都运行良好。
数据集为 CIFAR10,batch size 为 128。当使用 2 个 GPU 时,每个 GPU 处理 64 个样本,然后对每个 GPU 的梯度进行平均,就像 tensorflow 教程CIFAR10 multi-gpu一样。
我的 tensorflow 版本是 1.1.0,python 版本是 2.7,操作系统是 ubuntu 16.04。CUDA 版本是 7.5.17。
我的问题是为什么使用 tf.train.ExponentialMovingAverage(方法 1)在使用多个 GPU 时效果如此糟糕?我真的很困惑。
方法一:
def batch_norm_conv(x, n_out = 3, phase_train=True, scope='bn_conv'):
with tf.variable_scope(scope):
beta = tf.get_variable('beta_conv', shape=[n_out], initializer=tf.constant_initializer(0.0))
gamma = tf.get_variable('gamma_conv', shape=[n_out], initializer=tf.constant_initializer(1.0))
batch_mean_temp, batch_var_temp = tf.nn.moments(x, [0,1,2], name='moments')
batch_mean = tf.get_variable('batch_mean', shape=batch_mean_temp.get_shape(), initializer=tf.constant_initializer(0.0), trainable=False)
batch_var = tf.get_variable('batch_var', shape=batch_var_temp.get_shape(), initializer=tf.constant_initializer(0.0), trainable=False)
mean_op = tf.assign(batch_mean, batch_mean_temp)
var_op = tf.assign(batch_var, batch_var_temp)
ema = tf.train.ExponentialMovingAverage(decay=0.5, zero_debias=False)
ema_apply_op = ema.apply([batch_mean, batch_var])
def mean_var_with_update():
with tf.control_dependencies([mean_op, var_op]):
with tf.control_dependencies([ema_apply_op]):
return tf.identity(batch_mean), tf.identity(batch_var)
mean, var = tf.cond(phase_train,
mean_var_with_update,
lambda: (ema.average(batch_mean), ema.average(batch_var)))
normed = tf.nn.batch_normalization(x, mean, var, beta, gamma, 1e-5)
return normed
方法二:
def batch_norm_conv(x, n_out = 3, phase_train=True, scope='bn_conv'):
with tf.variable_scope(scope):
beta = tf.get_variable('beta_conv', shape=[n_out], initializer=tf.constant_initializer(0.0))
gamma = tf.get_variable('gamma_conv', shape=[n_out], initializer=tf.constant_initializer(1.0))
batch_mean, batch_var = tf.nn.moments(x, [0,1,2], name='moments')
print(batch_mean.get_shape())
mean_average = tf.get_variable('mean_average', shape=batch_mean.get_shape(), initializer=tf.constant_initializer(0.0))
var_average = tf.get_variable('var_average', shape=batch_var.get_shape(), initializer=tf.constant_initializer(0.0))
decay=0.5
def mean_var_with_update():
mean_temp = decay * mean_average + (1-decay) * batch_mean
var_temp = decay * var_average + (1-decay) * batch_var
mean_op = tf.assign(mean_average, mean_temp)
var_op = tf.assign(var_average, var_temp)
with tf.control_dependencies([mean_op, var_op]):
return tf.identity(batch_mean), tf.identity(batch_var)
mean, var = tf.cond(phase_train, mean_var_with_update, lambda: (mean_average, var_average))
normed = tf.nn.batch_normalization(x, mean, var, beta, gamma, 1e-3)
return normed
方法三:
from tensorflow.python.training import moving_averages
def batch_norm_conv(x, n_out = 3, phase_train=True, scope='bn_conv'):
with tf.variable_scope(scope):
beta = tf.get_variable('beta_conv', shape=[n_out], initializer=tf.constant_initializer(0.0))
gamma = tf.get_variable('gamma_conv', shape=[n_out], initializer=tf.constant_initializer(1.0))
batch_mean, batch_var = tf.nn.moments(x, [0,1,2], name='moments')
moving_mean = tf.get_variable('batch_mean', shape=batch_mean.get_shape(), initializer=tf.constant_initializer(0.0), trainable=False)
moving_variance = tf.get_variable('batch_var', shape=batch_var.get_shape(), initializer=tf.constant_initializer(0.0), trainable=False)
update_moving_mean = moving_averages.assign_moving_average(moving_mean,
batch_mean, 0.5, zero_debias=False)
update_moving_variance = moving_averages.assign_moving_average(
moving_variance, batch_var, 0.5, zero_debias=False)
def mean_var_with_update():
with tf.control_dependencies([update_moving_mean, update_moving_variance]):
return tf.identity(batch_mean), tf.identity(batch_var)
mean, var = tf.cond(phase_train,
mean_var_with_update,
lambda: (moving_mean, moving_variance))
normed = tf.nn.batch_normalization(x, mean, var, beta, gamma, 1e-3)
return normed
方法四:
def batch_norm_conv(x, phase_train=True, scope='bn_conv'):
with tf.variable_scope(scope):
normed = tf.contrib.layers.batch_norm(x,
center=True,
scale=True,
is_training = phase_train,
decay = 0.5,
trainable = True
)
return normed
方法 1 的结果:
('epoch', 0, 'test accuracy:', 0.12970000132918358)
('epoch', 1, 'test accuracy:', 0.20419999957084656)
('epoch', 2, 'test accuracy:', 0.11649999991059304)
('epoch', 3, 'test accuracy:', 0.12790000066161156)
('epoch', 4, 'test accuracy:', 0.17040000036358832)
('epoch', 5, 'test accuracy:', 0.15139999836683274)
('epoch', 6, 'test accuracy:', 0.13050000220537186)
('epoch', 7, 'test accuracy:', 0.15879999995231628)
('epoch', 8, 'test accuracy:', 0.17370000183582307)
('epoch', 9, 'test accuracy:', 0.17910000011324884)
('epoch', 10, 'test accuracy:', 0.17960000038146973)
('epoch', 11, 'test accuracy:', 0.12400000095367432)
('epoch', 12, 'test accuracy:', 0.13669999763369561)
('epoch', 13, 'test accuracy:', 0.25510000437498093)
('epoch', 14, 'test accuracy:', 0.18769999742507934)
('epoch', 15, 'test accuracy:', 0.16730000004172324)
('epoch', 16, 'test accuracy:', 0.15510000288486481)
('epoch', 17, 'test accuracy:', 0.19639999866485597)
('epoch', 18, 'test accuracy:', 0.24789999574422836)
('epoch', 19, 'test accuracy:', 0.15929999947547913)
('epoch', 20, 'test accuracy:', 0.17439999729394912)
方法 2 - 4 的结果(它们没有太大区别,所以只发布其中一个):
('epoch', 0, 'test accuracy:', 0.27250000238418581)
('epoch', 1, 'test accuracy:', 0.42709999978542329)
('epoch', 2, 'test accuracy:', 0.50179999470710757)
('epoch', 3, 'test accuracy:', 0.56709998846054077)
('epoch', 4, 'test accuracy:', 0.59760001301765442)
('epoch', 5, 'test accuracy:', 0.66010000705718996)
('epoch', 6, 'test accuracy:', 0.65400000214576726)
('epoch', 7, 'test accuracy:', 0.69880000352859495)
('epoch', 8, 'test accuracy:', 0.69749999642372129)
('epoch', 9, 'test accuracy:', 0.71029999256134035)
('epoch', 10, 'test accuracy:', 0.72619999051094053)
('epoch', 11, 'test accuracy:', 0.72920000553131104)
('epoch', 12, 'test accuracy:', 0.7372000098228455)
('epoch', 13, 'test accuracy:', 0.75380001068115232)
('epoch', 14, 'test accuracy:', 0.74269998073577881)
('epoch', 15, 'test accuracy:', 0.76199999451637268)
('epoch', 16, 'test accuracy:', 0.7636999785900116)
('epoch', 17, 'test accuracy:', 0.76039999723434448)
('epoch', 18, 'test accuracy:', 0.77150000333786006)
('epoch', 19, 'test accuracy:', 0.77920001149177553)
('epoch', 20, 'test accuracy:', 0.79100000858306885)