我正在尝试tf.train.ExponentialMovingAverage与PartitionedVariable.
我使用 acustom_getter创建图表的 EMA 版本。
如果我不使用 apartitioner创建我的变量,则以下代码按预期工作:将变量设置为零后,衰减为 1,此变量的 EMA 版本保持原始值。
但是,如果我使用分区器,我会遇到以下问题
tf1.12ema_getter无法找到的平均值因此PartitionedVariable两个变量是同一个对象
tf.1.15我得到一个AttributeError: 'PartitionedVariable' object has no attribute 'experimental_ref'
这是我的代码
import tensorflow as tf
import numpy as np
def ema_getter(ema):
def _ema_getter(getter, name, *args, **kwargs):
var = getter(name, *args, **kwargs)
ema_var = ema.average(var)
if not ema_var:
tf.logging.warning(f"Unable to find EMA of {name}")
return ema_var if ema_var else var
return _ema_getter
if __name__ == "__main__":
use_partitioner = True
var = tf.get_variable(
name='var',
shape=[10, 2],
initializer=tf.ones_initializer(),
partitioner=tf.fixed_size_partitioner(2, axis=0) if use_partitioner else None
)
var_sum = tf.reduce_sum(var)
ema = tf.train.ExponentialMovingAverage(1.0)
variables = tf.get_collection(
tf.GraphKeys.TRAINABLE_VARIABLES, tf.get_variable_scope().name)
ema_op = ema.apply(variables)
with tf.variable_scope(tf.get_variable_scope(), reuse=True, custom_getter=ema_getter(ema)):
var_ema = tf.get_variable(
name='var',
shape=[10, 2],
partitioner=tf.fixed_size_partitioner(2, axis=0) if use_partitioner else None
)
print(f"EMA variable name: {var_ema.name}")
var_ema_sum = tf.reduce_sum(var_ema)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
sess.run(ema_op)
print(sess.run(var_sum)) # 20.0
print(sess.run(var_ema_sum)) # 20.0
sess.run(tf.assign(var, tf.zeros_like(var)))
sess.run(ema_op)
print(sess.run(var_sum)) # 0.0
print(sess.run(var_ema_sum)) # should be 20.0