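I'm trying to load several checkpoints with TF 2.1 and save their averaged weights. I found a TF1 version of this: https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/avg_checkpoints.py
The variable `checkpoints` is a list of checkpoint paths.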
# Read variables from all checkpoints and average them.
logger.info("Reading variables and averaging checkpoints:")
for c in checkpoints:
    logger.info(c)
var_list = tf.train.list_variables(checkpoints[0])
var_values, var_dtypes = {}, {}
for (name, shape) in var_list:
    if not name.startswith("global_step"):
        var_values[name] = tf.zeros(shape)
for checkpoint in checkpoints:
    reader = tf.train.load_checkpoint(checkpoint)
    for name in var_values:
        tensor = tf.convert_to_tensor(reader.get_tensor(name))
        if tensor.dtype == tf.string:
            # String entries cannot be summed; keep the latest value.
            var_values[name] = tensor
        else:
            var_values[name] = tf.cast(var_values[name], tensor.dtype)
            var_values[name] += tensor
        var_dtypes[name] = tensor.dtype
    logger.info("Read from checkpoint %s", checkpoint)
for name in var_values:  # Average.
    if var_dtypes[name] != tf.string:
        var_values[name] /= len(checkpoints)
Could you explain how to save the averaged values in var_values to a checkpoint?
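
For reference, here is a minimal sketch of one way this might be done in TF2. Note that it swaps in a different technique from the var_values dictionary above: instead of writing raw name-to-tensor pairs, it restores each checkpoint into live model objects, accumulates their variables, assigns the average back, and writes a new object-based checkpoint. It assumes the model can be rebuilt in Python and that the checkpoints were written with tf.train.Checkpoint(model=model); the key `model` must match however the originals were saved. TinyModel and average_and_save are hypothetical names used only for illustration.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


class TinyModel(tf.Module):
    """Hypothetical stand-in for the real model."""

    def __init__(self):
        super().__init__()
        self.w = tf.Variable([0.0, 0.0])


def average_and_save(model, checkpoint_paths, output_prefix):
    """Hypothetical helper: averages model.variables over checkpoints and saves."""
    # Assumption: the checkpoints were saved as tf.train.Checkpoint(model=model).
    ckpt = tf.train.Checkpoint(model=model)
    sums = [np.zeros_like(v.numpy()) for v in model.variables]
    for path in checkpoint_paths:
        # Variables already exist, so restore() overwrites them immediately.
        ckpt.restore(path).expect_partial()
        for total, v in zip(sums, model.variables):
            total += v.numpy()
    for total, v in zip(sums, model.variables):
        # Cast back so integer variables keep their dtype.
        v.assign((total / len(checkpoint_paths)).astype(total.dtype))
    # write() saves without a numbered suffix and returns the checkpoint path.
    return ckpt.write(output_prefix)


# Demo: write two checkpoints with different weights, then average them.
tmp_dir = tempfile.mkdtemp()
model = TinyModel()
ckpt = tf.train.Checkpoint(model=model)
paths = []
for i, value in enumerate([[1.0, 2.0], [3.0, 6.0]]):
    model.w.assign(value)
    paths.append(ckpt.write(os.path.join(tmp_dir, "ckpt-%d" % i)))

averaged_path = average_and_save(model, paths, os.path.join(tmp_dir, "averaged"))

# Load the averaged checkpoint into a fresh model to check the result.
restored = TinyModel()
tf.train.Checkpoint(model=restored).restore(averaged_path)
print(restored.w.numpy())
```

Going through the model objects keeps the checkpoint in the object-based format that tf.train.Checkpoint.restore expects, so the variable keys and the object graph entry do not have to be reconstructed by hand.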