1

我正在尝试将 tf.train.ExponentialMovingAverage 与 PartitionedVariable 一起使用。

我使用一个 custom_getter 来创建计算图的 EMA 版本。

如果我创建变量时不使用 partitioner,则以下代码按预期工作:在衰减率为 1 的情况下,将变量设置为零之后,该变量的 EMA 版本仍保持原始值。

但是,如果我使用分区器,我会遇到以下问题

在 tf 1.12 中:ema_getter 无法找到 PartitionedVariable 的平均值,因此两个变量是同一个对象。

在 tf 1.15 中:我得到一个 AttributeError: 'PartitionedVariable' object has no attribute 'experimental_ref'

这是我的代码

import tensorflow as tf
import numpy as np


def ema_getter(ema):
    """Build a custom getter that substitutes variables with their EMA.

    Args:
        ema: Object exposing ``average(var)`` (here a
            ``tf.train.ExponentialMovingAverage``). ``average`` is expected to
            return ``None`` when no moving average is maintained for ``var``.

    Returns:
        A callable with the ``custom_getter`` signature
        ``(getter, name, *args, **kwargs)``. It returns the moving-average
        counterpart of the requested variable, or the variable itself (after
        logging a warning) when no average is found.
    """
    def _ema_getter(getter, name, *args, **kwargs):
        var = getter(name, *args, **kwargs)
        # NOTE: `var` is handed to `ema.average` unchanged, so a partitioned
        # variable is looked up as a whole rather than partition by partition.
        ema_var = ema.average(var)
        # Compare against None explicitly: truth-testing a variable-like
        # object would go through its __bool__/__len__ protocol, which can
        # report a valid average as "missing" or raise.
        if ema_var is None:
            tf.logging.warning(f"Unable to find EMA of {name}")
            return var
        return ema_var

    return _ema_getter


if __name__ == "__main__":
    # Reproduction script: track a (optionally partitioned) variable with an
    # exponential moving average of decay 1.0, zero the variable, and compare
    # the sum of the raw variable with the sum obtained through the EMA getter.
    use_partitioner = True
    # The same partitioner and shape are used for both get_variable calls.
    partitioner = tf.fixed_size_partitioner(2, axis=0) if use_partitioner else None
    shape = [10, 2]

    var = tf.get_variable(
        name='var',
        shape=shape,
        initializer=tf.ones_initializer(),
        partitioner=partitioner,
    )
    var_sum = tf.reduce_sum(var)

    ema = tf.train.ExponentialMovingAverage(1.0)
    root_scope = tf.get_variable_scope()
    ema_op = ema.apply(
        tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, root_scope.name))

    # Re-enter the current scope in reuse mode so that get_variable is routed
    # through the EMA-substituting custom getter.
    with tf.variable_scope(root_scope, reuse=True, custom_getter=ema_getter(ema)):
        var_ema = tf.get_variable(name='var', shape=shape, partitioner=partitioner)
        print(f"EMA variable name: {var_ema.name}")
        var_ema_sum = tf.reduce_sum(var_ema)

    sums = (var_sum, var_ema_sum)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        sess.run(ema_op)
        # Expected output: 20.0 for the variable, 20.0 for its EMA.
        for total in sums:
            print(sess.run(total))

        zero_out = tf.assign(var, tf.zeros_like(var))
        sess.run(zero_out)
        sess.run(ema_op)
        # Expected output: 0.0 for the variable; the EMA should be 20.0.
        for total in sums:
            print(sess.run(total))
4

1 回答 1

1

到目前为止,我的理解是,PartitionedVariable 的行为并不像标准的 Variable,而仅仅是由其他 Variable 组成的列表的外壳。

custom_getter 需要考虑到这一点:手动检索原始变量各分区的 EMA 版本,并用它们重建一个 PartitionedVariable。

但是,这种做法似乎有些取巧——例如使用 .__class__(因为我找不到以干净的方式导入 PartitionedVariable 的方法),或者访问 _partitions 这样的私有属性。

在这里分享我当前的修复

import tensorflow as tf
import numpy as np


def ema_getter(ema):
    """Build a custom getter that substitutes variables with their EMA.

    Partitioned variables are handled by looking up the average of every
    partition and wrapping those averages in a new object of the same class.

    Args:
        ema: Object exposing ``average(var)`` (here a
            ``tf.train.ExponentialMovingAverage``). ``average`` is expected to
            return ``None`` when no moving average is maintained for ``var``.

    Returns:
        A callable with the ``custom_getter`` signature
        ``(getter, name, *args, **kwargs)``. It returns the moving-average
        counterpart of the requested variable, or the variable itself (after
        logging a warning) when the average, or the average of any of its
        partitions, is not found.
    """
    def _average_of(var):
        """Return the EMA counterpart of ``var``, or None if it is incomplete."""
        # The class is matched by name because the answer found no clean way
        # to import PartitionedVariable.
        if var.__class__.__name__ != "PartitionedVariable":
            return ema.average(var)

        partition_averages = [ema.average(part) for part in var]
        # Never rebuild from an incomplete list: a missing partition average
        # would otherwise be passed to the constructor as None.
        if any(avg is None for avg in partition_averages):
            return None
        return var.__class__(
            name=var.name,
            shape=var.shape,
            dtype=var.dtype,
            variable_list=partition_averages,
            partitions=var._partitions,  # private attribute; no public accessor used here
        )

    def _ema_getter(getter, name, *args, **kwargs):
        var = getter(name, *args, **kwargs)
        ema_var = _average_of(var)
        # Compare against None explicitly: truth-testing a variable-like
        # object would go through its __bool__/__len__ protocol, which can
        # report a valid average as "missing" or raise.
        if ema_var is None:
            tf.logging.warning(f"Unable to find EMA of {name}")
            return var
        return ema_var

    return _ema_getter


if __name__ == "__main__":
    # Demo script for the EMA custom getter: create a (optionally partitioned)
    # variable of ones, track it with an EMA of decay 1.0, then zero the
    # variable and print the raw sum next to the sum read through the getter.
    use_partitioner = True
    # Both get_variable calls share this partitioner and shape.
    partitioner = tf.fixed_size_partitioner(2, axis=0) if use_partitioner else None
    shape = [10, 2]

    var = tf.get_variable(
        name='var',
        shape=shape,
        initializer=tf.ones_initializer(),
        partitioner=partitioner,
    )
    var_sum = tf.reduce_sum(var)

    ema = tf.train.ExponentialMovingAverage(1.0)
    current_scope = tf.get_variable_scope()
    trainable = tf.get_collection(
        tf.GraphKeys.TRAINABLE_VARIABLES, current_scope.name)
    ema_op = ema.apply(trainable)

    # Reuse the existing variable, but fetch it through the EMA getter.
    with tf.variable_scope(current_scope, reuse=True, custom_getter=ema_getter(ema)):
        var_ema = tf.get_variable(name='var', shape=shape, partitioner=partitioner)
        print(f"EMA variable name: {var_ema.name}")
        var_ema_sum = tf.reduce_sum(var_ema)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        # Report both sums after the first EMA update.
        sess.run(ema_op)
        for fetch in (var_sum, var_ema_sum):
            print(sess.run(fetch))

        # Zero the variable, update the EMA again, and report both sums.
        reset_op = tf.assign(var, tf.zeros_like(var))
        sess.run(reset_op)
        sess.run(ema_op)
        for fetch in (var_sum, var_ema_sum):
            print(sess.run(fetch))
于 2019-12-24T11:18:04.933 回答