我正在尝试在 TF 2.x 中实现梯度累积。我找到的所有实现都是针对 TF 1.x 或旧版 Keras 接口的。据我所知,TF 2.x 中还没有现成的实现(如果有,很高兴被纠正)。


import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Flatten, Dense
from tqdm import tqdm
import matplotlib.pyplot as plt

class SimpleTrainStepModel(Model):
    """Keras model with a hand-written ``train_step``, used as a baseline for comparison."""

    def train_step(self, data):
        """Run one optimisation step on ``data`` and return the current metric values.

        ``data`` is either ``(x, y)`` or ``(x, y, sample_weight)``.
        Returns a dict mapping metric names to their current results.
        """
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        if len(data) == 3:
            x, y, sample_weight = data
        else:
            (x, y), sample_weight = data, None

        with tf.GradientTape() as tape:
            y_pred = self(x, training = True)  # Forward pass
            loss = self.compiled_loss(y, y_pred, sample_weight = sample_weight, regularization_losses = self.losses)
        gradients = tape.gradient(loss, self.trainable_variables)
        # Pass sample_weight so weighted metrics see the same weights as the loss.
        self.compiled_metrics.update_state(y, y_pred, sample_weight)

        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))

        return {m.name: m.result() for m in self.metrics}

class GradAccumModel(Model):
    """Keras model whose ``fit`` accepts ``grad_accum`` to accumulate gradients over micro-batches.

    Every batch of ``batch_size`` samples is split into ``grad_accum`` micro-batches.
    Their gradients are combined into the gradient of the full-batch mean loss,
    and a single optimizer update is applied per batch.
    """

    def fit(self, *args, batch_size = 32, grad_accum = 1, **kwargs):
        """Call ``Model.fit`` with gradient accumulation over ``grad_accum`` micro-batches.

        Raises ValueError if ``grad_accum`` < 1 or ``batch_size`` is not divisible by it.
        """
        # Drop any cached train function so it is re-traced with the new settings.
        self.train_function = None
        if grad_accum < 1:
            raise ValueError(f'grad_accum must be >= 1, got {grad_accum}')
        if batch_size % grad_accum != 0:
            raise ValueError('Batch size must be divisible by the Gradient accumulation steps, dummy!')
        self.grad_accum = grad_accum
        self.batch_size = batch_size
        return super(GradAccumModel, self).fit(*args,
                                               batch_size = self.batch_size,
                                               **kwargs)

    def train_step(self, data):
        """One optimizer update built from ``grad_accum`` accumulated micro-batch gradients.

        ``data`` is either ``(x, y)`` or ``(x, y, sample_weight)``.
        Returns a dict mapping metric names to their current results.
        """
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        if len(data) == 3:
            x, y, sample_weight = data
        else:
            (x, y), sample_weight = data, None

        step = self.batch_size // self.grad_accum
        # Number of samples actually in this batch. The last batch of an epoch can be
        # smaller than batch_size, so the loop is bounded by the real size.
        n_samples = tf.shape(tf.nest.flatten(x)[0])[0]

        def _slice(struct, start, stop):
            # Slice every tensor of a (possibly nested) structure; None passes through.
            if struct is None:
                return None
            return tf.nest.map_structure(lambda t: t[start:stop], struct)

        def _micro_batch_gradients(start, stop):
            # Forward/backward pass on samples [start, stop). The gradients are weighted
            # by (micro-batch size / batch size), so their sum is the gradient of the
            # batch-mean loss. For full batches this weight is exactly 1 / grad_accum.
            x_ = _slice(x, start, stop)
            y_ = _slice(y, start, stop)
            w_ = _slice(sample_weight, start, stop)
            with tf.GradientTape() as tape:
                y_pred = self(x_, training = True)  # Forward pass
                loss = self.compiled_loss(y_, y_pred, sample_weight = w_, regularization_losses = self.losses)
            grads = tape.gradient(loss, self.trainable_variables)
            self.compiled_metrics.update_state(y_, y_pred, w_)
            micro_size = tf.shape(tf.nest.flatten(x_)[0])[0]
            weight = tf.cast(micro_size, tf.float32) / tf.cast(n_samples, tf.float32)
            return [g * tf.cast(weight, g.dtype) for g in grads]

        gradients = _micro_batch_gradients(0, step)

        def cond(i, *args):
            return i < n_samples

        def body(i, grad):
            new_grad = _micro_batch_gradients(i, i + step)
            # Build a NEW list: `g += _g` on loop variables only rebinds a local name
            # and silently discards the accumulated gradient.
            return [i + step, [g + ng for g, ng in zip(grad, new_grad)]]

        _, gradients = tf.while_loop(cond, body, [tf.constant(step), gradients], parallel_iterations = 1)

        # Update weights once with the accumulated gradient.
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))

        # Return a dict mapping metric names to current value
        return {m.name: m.result() for m in self.metrics}

if __name__ == '__main__':
    (x_train, y_train), (x_valid, y_valid) = tf.keras.datasets.mnist.load_data()

    # Each experiment: (model class, extra kwargs forwarded to fit, plot colour).
    experiments = [
        (Model, {}, 'blue'),
        (SimpleTrainStepModel, {}, 'green'),
        (GradAccumModel, {'grad_accum': 1}, 'yellow'),
        (GradAccumModel, {'grad_accum': 6}, 'red'),
    ]

    for model_cls, fit_kwargs, colour in experiments:
        # Ten independent runs per configuration, drawn semi-transparently on one plot.
        for _run in tqdm(range(10)):
            inputs = Input((28, 28))
            features = Flatten()(inputs)
            hidden = Dense(128, activation = 'sigmoid')(features)
            outputs = Dense(10, activation = 'softmax')(hidden)

            model = model_cls(inputs, outputs)
            model.compile(loss = tf.keras.losses.SparseCategoricalCrossentropy(),
                          optimizer = tf.keras.optimizers.Adam(1e-4),
                          metrics = ['acc'])

            history = model.fit(x_train, y_train,
                                validation_data = (x_valid, y_valid),
                                verbose = 0,
                                batch_size = 6000,
                                epochs = 100,
                                **fit_kwargs)
            plt.plot(history.history['val_acc'], color = colour, alpha = .25)


我已经验证它确实节省了 GPU 显存。但是,最终的训练结果与普通的 `Model.fit` 不一致。



如您所见,前三种 `Model.fit` 的曲线很好地聚集在一起,结果基本相同。但一旦 `while` 循环开始起作用(即 `grad_accum > 1`),训练结果就完全不同了。



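1 个回答

问题的根源在于 `body` 中的 `for g, _g in zip(grad, _grad): g += _g`:它只是把局部变量 `g` 重新绑定为新张量,并没有修改 `grad` 列表,因此返回的仍是第一个微批次的梯度,其余微批次的梯度都被丢弃了。下面的实现改为用 `grad = [g + _g for g, _g in zip(grad, _grad)]` 构造新列表,并把每个微批次的梯度乘以 `1/grad_accum_steps` 来求平均。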



from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer as lso
from tensorflow.python.distribute import parameter_server_strategy
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.util import nest
from tensorflow.keras.models import Model as _Model

class Model(_Model):
    """Drop-in replacement for the Keras ``Model`` that supports gradient accumulation.

    Pass ``grad_accum_steps`` to :meth:`fit` to split every (per-replica) batch into
    that many micro-batches. Their gradients are combined before a single optimizer update.
    """

    def fit(self, *args, batch_size: int = 32, grad_accum_steps: int = 1, **kwargs):
        """Shallow wrapper of Model.fit that captures batch_size and the additional kwarg grad_accum_steps.

        Parameters
        ----------
        batch_size : int
            Same as in Model.fit.
        grad_accum_steps : int
            Number of micro-batches to split ``batch_size`` into. ``batch_size`` must be
            divisible by ``grad_accum_steps`` times the number of replicas (defaults to 1,
            which is a plain ``Model.fit``).

        Returns
        -------
        Whatever ``Model.fit`` returns.

        Raises
        ------
        ValueError
            If ``grad_accum_steps`` < 1, or ``batch_size`` is not divisible by
            ``grad_accum_steps * num_replicas``.
        """
        if grad_accum_steps < 1:
            raise ValueError(f'grad_accum_steps must be >= 1, got {grad_accum_steps}')
        if grad_accum_steps == 1:
            # No accumulation requested: plain Model.fit (return so we do not train twice).
            return super().fit(*args, batch_size = batch_size, **kwargs)

        num_workers = ds_context.get_strategy().num_replicas_in_sync
        if batch_size % (grad_accum_steps * num_workers) != 0:
            raise ValueError(f'Batch size ({batch_size}) must be divisible by the Gradient accumulation steps ({grad_accum_steps}), and the number of replicas ({num_workers}), dummy!')

        self._grad_accum_ = grad_accum_steps
        self._batch_size_ = batch_size
        self._num_workers_ = num_workers
        train_step_backup = self.train_step
        # Force Keras to re-trace its train function so it picks up _train_step_.
        self.train_function = None
        self.train_step = self._train_step_
        try:
            return super().fit(*args,
                               batch_size = self._batch_size_,  # TODO maybe consider validation batch size
                               **kwargs)
        finally:
            # Restore the model even if training raised.
            del self._grad_accum_
            del self._batch_size_
            del self._num_workers_
            self.train_step = train_step_backup
            # The cached train function has the accumulating step traced into it.
            # Drop it so a later fit() re-traces with the restored train_step.
            self.train_function = None

    def _train_step_(self, data):
        """Custom training step taking gradient accumulation into account, for low-memory training.

        The per-replica batch is processed in micro-batches of
        ``batch_size // num_replicas // grad_accum_steps`` samples. Each micro-batch
        gradient is weighted by (micro-batch size / batch size), which is
        ``1 / grad_accum_steps`` for full batches. The summed result is the gradient
        of the batch-mean loss.

        Returns
        -------
        dict mapping metric names to their current values.
        """
        if len(data) == 3:
            x, y, sample_weight = data
        else:
            (x, y), sample_weight = data, None

        def slice_map(struct, start, stop):  # dealing with nasty nested structures
            if struct is None:
                return None  # special case for sample_weight

            return nest.map_structure(lambda t: t[start:stop], struct)

        # ---------- GRAD ACCUM STUFF ----------------------------------------------------------------------------------
        step = self._batch_size_ // self._num_workers_ // self._grad_accum_
        # Samples actually received by this replica. Bounding the loop by the global
        # batch size would run extra iterations on empty slices with >1 replica, and
        # a partial last batch is smaller than batch_size anyway.
        n_samples = tf.shape(nest.flatten(x)[0])[0]

        def micro_batch_gradients(start, stop):
            # Forward/backward pass on samples [start, stop). Returns loss-scaled
            # (when mixed precision is used) gradients, weighted by micro-batch share.
            x_ = slice_map(x, start, stop)
            y_ = slice_map(y, start, stop)
            w_ = slice_map(sample_weight, start, stop)

            with tf.GradientTape() as tape:
                y_pred = self(x_, training = True)  # Forward pass
                loss = self.compiled_loss(y_, y_pred, sample_weight = w_, regularization_losses = self.losses)
                if isinstance(self.optimizer, lso.LossScaleOptimizer):
                    loss = self.optimizer.get_scaled_loss(loss)

            grads = tape.gradient(loss, self.trainable_variables)
            self.compiled_metrics.update_state(y_, y_pred, w_)
            micro_size = tf.shape(nest.flatten(x_)[0])[0]
            weight = tf.cast(micro_size, tf.float32) / tf.cast(n_samples, tf.float32)
            # NOTE(review): fails if a trainable variable has a None gradient (as the original did).
            return [g * tf.cast(weight, g.dtype) for g in grads]

        gradients = micro_batch_gradients(0, step)

        def cond(i, *args):
            return i < n_samples

        def body(i, grad):
            _grad = micro_batch_gradients(i, i + step)
            # Build a new list; in-place `g += _g` would only rebind a local name.
            grad = [g + _g for g, _g in zip(grad, _grad)]
            return [i + step, grad]

        _, gradients = tf.while_loop(cond, body, [tf.constant(step), gradients], parallel_iterations = 1)
        # --------------------------------------------------------------------------------------------------------------

        # ---------- STUFF FROM Model._minimize ------------------------------------------------------------------------
        aggregate_grads_outside_optimizer = (
            self.optimizer._HAS_AGGREGATE_GRAD
            and not isinstance(self.distribute_strategy.extended,
                               parameter_server_strategy.ParameterServerStrategyExtended))

        if aggregate_grads_outside_optimizer:
            gradients = self.optimizer._aggregate_gradients(zip(gradients, self.trainable_variables))

        if isinstance(self.optimizer, lso.LossScaleOptimizer):
            gradients = self.optimizer.get_unscaled_gradients(gradients)

        gradients = self.optimizer._clip_gradients(gradients)
        if self.trainable_variables:
            # Exactly one apply_gradients per step (the two calls are alternatives).
            if aggregate_grads_outside_optimizer:
                self.optimizer.apply_gradients(zip(gradients, self.trainable_variables), experimental_aggregate_gradients = False)
            else:
                self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
        # --------------------------------------------------------------------------------------------------------------

        return {m.name: m.result() for m in self.metrics}
回答于 2021-03-17 14:03:01