2

最近,我尝试使用多个 GPU 来加速训练。但是我对批量标准化有一些问题。具体来说,当我使用ExponentialMovingAverage来获得平均批次平均值和变化时,准确性很差。

我尝试了几种方法(函数)来实现批量标准化,如下所示。其余部分代码都是一样的,我只是尝试了不同的批量归一化函数。使用 2 个 GPU 时,方法 2-4 运行良好,但方法 1 在测试数据集上的准确度确实很差。当我切换到只使用 1 个 GPU 时,所有方法都运行良好。

数据集为 CIFAR10,batch size 为 128。当使用 2 个 GPU 时,每个 GPU 处理 64 个样本,然后对每个 GPU 的梯度进行平均,就像 tensorflow 教程CIFAR10 multi-gpu一样。

我的 tensorflow 版本是 1.1.0,python 版本是 2.7,操作系统是 ubuntu 16.04。CUDA 版本是 7.5.17。

我的问题是为什么使用 tf.train.ExponentialMovingAverage(方法 1)在使用多个 GPU 时效果如此糟糕?我真的很困惑。

方法一:

def batch_norm_conv(x, n_out = 3, phase_train=True, scope='bn_conv'):
    with tf.variable_scope(scope):
        beta = tf.get_variable('beta_conv', shape=[n_out], initializer=tf.constant_initializer(0.0))
        gamma = tf.get_variable('gamma_conv', shape=[n_out], initializer=tf.constant_initializer(1.0))

        batch_mean_temp, batch_var_temp = tf.nn.moments(x, [0,1,2], name='moments')
        batch_mean = tf.get_variable('batch_mean', shape=batch_mean_temp.get_shape(), initializer=tf.constant_initializer(0.0), trainable=False)
        batch_var = tf.get_variable('batch_var', shape=batch_var_temp.get_shape(), initializer=tf.constant_initializer(0.0), trainable=False)

        mean_op = tf.assign(batch_mean, batch_mean_temp)
        var_op = tf.assign(batch_var, batch_var_temp)

        ema = tf.train.ExponentialMovingAverage(decay=0.5, zero_debias=False)
        ema_apply_op = ema.apply([batch_mean, batch_var])
        def mean_var_with_update():
            with tf.control_dependencies([mean_op, var_op]):
                with tf.control_dependencies([ema_apply_op]):
                    return tf.identity(batch_mean), tf.identity(batch_var)

        mean, var = tf.cond(phase_train,
                        mean_var_with_update,
                        lambda: (ema.average(batch_mean), ema.average(batch_var)))    
        normed = tf.nn.batch_normalization(x, mean, var, beta, gamma, 1e-5)
    return normed

方法二:

def batch_norm_conv(x, n_out = 3, phase_train=True, scope='bn_conv'):
  with tf.variable_scope(scope):
      beta = tf.get_variable('beta_conv', shape=[n_out], initializer=tf.constant_initializer(0.0))
      gamma = tf.get_variable('gamma_conv', shape=[n_out], initializer=tf.constant_initializer(1.0))

      batch_mean, batch_var = tf.nn.moments(x, [0,1,2], name='moments')
      print(batch_mean.get_shape())
      mean_average = tf.get_variable('mean_average', shape=batch_mean.get_shape(), initializer=tf.constant_initializer(0.0))
      var_average = tf.get_variable('var_average', shape=batch_var.get_shape(), initializer=tf.constant_initializer(0.0))

      decay=0.5
      def mean_var_with_update():  
          mean_temp = decay * mean_average + (1-decay) * batch_mean
          var_temp = decay * var_average + (1-decay) * batch_var
          mean_op = tf.assign(mean_average, mean_temp)
          var_op = tf.assign(var_average, var_temp)
          with tf.control_dependencies([mean_op, var_op]):
              return tf.identity(batch_mean), tf.identity(batch_var)

      mean, var = tf.cond(phase_train, mean_var_with_update, lambda: (mean_average, var_average))
      normed = tf.nn.batch_normalization(x, mean, var, beta, gamma, 1e-3)
  return normed

方法三:

from tensorflow.python.training import moving_averages

def batch_norm_conv(x, n_out = 3, phase_train=True, scope='bn_conv'):

  with tf.variable_scope(scope):
      beta = tf.get_variable('beta_conv', shape=[n_out], initializer=tf.constant_initializer(0.0))
      gamma = tf.get_variable('gamma_conv', shape=[n_out], initializer=tf.constant_initializer(1.0))

      batch_mean, batch_var = tf.nn.moments(x, [0,1,2], name='moments')

      moving_mean = tf.get_variable('batch_mean', shape=batch_mean.get_shape(), initializer=tf.constant_initializer(0.0), trainable=False)
      moving_variance = tf.get_variable('batch_var', shape=batch_var.get_shape(), initializer=tf.constant_initializer(0.0), trainable=False)

      update_moving_mean = moving_averages.assign_moving_average(moving_mean,
                                                                 batch_mean, 0.5, zero_debias=False)
      update_moving_variance = moving_averages.assign_moving_average(
                                                          moving_variance, batch_var, 0.5, zero_debias=False)
      def mean_var_with_update():
          with tf.control_dependencies([update_moving_mean, update_moving_variance]):
                  return tf.identity(batch_mean), tf.identity(batch_var)

      mean, var = tf.cond(phase_train, 
                                        mean_var_with_update,
                                        lambda: (moving_mean, moving_variance))

      normed = tf.nn.batch_normalization(x, mean, var, beta, gamma, 1e-3)
  return normed

方法四:

def batch_norm_conv(x, phase_train=True, scope='bn_conv'):

  with tf.variable_scope(scope):
      normed = tf.contrib.layers.batch_norm(x, 
                                            center=True, 
                                            scale=True, 
                                            is_training = phase_train,
                                            decay = 0.5,
                                            trainable = True
                                           )

  return normed

方法 1 的结果:

('epoch', 0, 'test accuracy:', 0.12970000132918358)
('epoch', 1, 'test accuracy:', 0.20419999957084656)
('epoch', 2, 'test accuracy:', 0.11649999991059304)
('epoch', 3, 'test accuracy:', 0.12790000066161156)
('epoch', 4, 'test accuracy:', 0.17040000036358832)
('epoch', 5, 'test accuracy:', 0.15139999836683274)
('epoch', 6, 'test accuracy:', 0.13050000220537186)
('epoch', 7, 'test accuracy:', 0.15879999995231628)
('epoch', 8, 'test accuracy:', 0.17370000183582307)
('epoch', 9, 'test accuracy:', 0.17910000011324884)
('epoch', 10, 'test accuracy:', 0.17960000038146973)
('epoch', 11, 'test accuracy:', 0.12400000095367432)
('epoch', 12, 'test accuracy:', 0.13669999763369561)
('epoch', 13, 'test accuracy:', 0.25510000437498093)
('epoch', 14, 'test accuracy:', 0.18769999742507934)
('epoch', 15, 'test accuracy:', 0.16730000004172324)
('epoch', 16, 'test accuracy:', 0.15510000288486481)
('epoch', 17, 'test accuracy:', 0.19639999866485597)
('epoch', 18, 'test accuracy:', 0.24789999574422836)
('epoch', 19, 'test accuracy:', 0.15929999947547913)
('epoch', 20, 'test accuracy:', 0.17439999729394912)

方法 2 - 4 的结果(它们没有太大区别,所以只发布其中一个):

('epoch', 0, 'test accuracy:', 0.27250000238418581)
('epoch', 1, 'test accuracy:', 0.42709999978542329)
('epoch', 2, 'test accuracy:', 0.50179999470710757)
('epoch', 3, 'test accuracy:', 0.56709998846054077)
('epoch', 4, 'test accuracy:', 0.59760001301765442)
('epoch', 5, 'test accuracy:', 0.66010000705718996)
('epoch', 6, 'test accuracy:', 0.65400000214576726)
('epoch', 7, 'test accuracy:', 0.69880000352859495)
('epoch', 8, 'test accuracy:', 0.69749999642372129)
('epoch', 9, 'test accuracy:', 0.71029999256134035)
('epoch', 10, 'test accuracy:', 0.72619999051094053)
('epoch', 11, 'test accuracy:', 0.72920000553131104)
('epoch', 12, 'test accuracy:', 0.7372000098228455)
('epoch', 13, 'test accuracy:', 0.75380001068115232)
('epoch', 14, 'test accuracy:', 0.74269998073577881)
('epoch', 15, 'test accuracy:', 0.76199999451637268)
('epoch', 16, 'test accuracy:', 0.7636999785900116)
('epoch', 17, 'test accuracy:', 0.76039999723434448)
('epoch', 18, 'test accuracy:', 0.77150000333786006)
('epoch', 19, 'test accuracy:', 0.77920001149177553)
('epoch', 20, 'test accuracy:', 0.79100000858306885)
4

0 回答 0