0

我一直在试图弄清楚如何计算 ResNet 反向传递中的 Flops 数量。对于前向传递,这似乎很简单:将卷积过滤器应用于每一层的输入。但是,在反向传递期间,Flops 如何计算梯度计算和所有权重的更新?

具体来说,

  • 如何在每一层的梯度计算中计算 Flops?

  • 需要计算所有梯度,以便计算每个梯度的翻牌数?

  • Pool、BatchNorm 和 Relu 层的梯度计算中有多少次 Flops?

我了解梯度计算的链式法则,但是很难确定它如何应用于 ResNet 的卷积层中的权重过滤器以及每个需要多少次 Flops。获得有关计算向后传递的总触发器的方法的任何评论将非常有用。谢谢

4

1 回答 1

0

您绝对可以手动计算反向传递的乘法和加法的数量,但我想这对于复杂模型来说是一个详尽的过程。

通常,对于 CNN 和其他模型,大多数模型都以触发器为基准进行正向传递,而不是反向触发器计数。我想原因与推理在应用程序中的不同 CNN 变体和其他深度学习模型方面更为重要有关。

向后传球仅在训练时很重要,对于大多数简单模型,向后和向前翻牌应该接近一些常数因素。

因此,我尝试了一种 hacky 方法来计算图中整个 resnet 模型的梯度,以获得前向传递和梯度计算的 flop 计数,然后减去前向 flop。这不是一个精确的测量,可能会错过复杂图形/模型的许多操作。

但这可能会给出大多数模型的失败估计。

[以下代码片段适用于 tensorflow 2.0]

import tensorflow as tf

def get_flops():

    for_flop = 0
    total_flop = 0
    session = tf.compat.v1.Session()
    graph = tf.compat.v1.get_default_graph()

    # forward
    with graph.as_default():
        with session.as_default():

            model = tf.keras.applications.ResNet50() # change your model here

            run_meta = tf.compat.v1.RunMetadata()
            opts = tf.compat.v1.profiler.ProfileOptionBuilder.float_operation()

            # We use the Keras session graph in the call to the profiler.
            flops = tf.compat.v1.profiler.profile(graph=graph,
                                                  run_meta=run_meta, cmd='op', options=opts)

            for_flop = flops.total_float_ops
            # print(for_flop)

    # forward + backward
    with graph.as_default():
        with session.as_default():

            model = tf.keras.applications.ResNet50() # change your model here


            outputTensor = model.output 
            listOfVariableTensors = model.trainable_weights
            gradients = tf.gradients(outputTensor, listOfVariableTensors)

            run_meta = tf.compat.v1.RunMetadata()
            opts = tf.compat.v1.profiler.ProfileOptionBuilder.float_operation()

            # We use the Keras session graph in the call to the profiler.
            flops = tf.compat.v1.profiler.profile(graph=graph,
                                                  run_meta=run_meta, cmd='op', options=opts)

            total_flop = flops.total_float_ops
            # print(total_flop)

    return for_flop, total_flop


for_flops, total_flops = get_flops()
print(f'forward: {for_flops}')
print(f'backward: {total_flops - for_flops}')

出去:

51112224
102224449
forward: 51112224
backward: 51112225
于 2020-05-05T08:28:06.970 回答