0

我正在尝试使用 TF2.1 加载检查点并保存它们的平均权重。我为它找到了 TF1 版本。 https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/avg_checkpoints.py

变量“检查点”是检查点路径的列表

  # Read variables from all checkpoints and average them.
  logger.info("Reading variables and averaging checkpoints:")
  for c in checkpoints:
    logger.info(c)

  var_list = tf.train.list_variables(checkpoints[0])

  var_values, var_dtypes = {}, {}
  for (name, shape) in var_list:
    if not name.startswith("global_step"):
      var_values[name] = tf.zeros(shape)

  for checkpoint in checkpoints:
    reader = tf.train.load_checkpoint(checkpoint)
    for name in var_values:
      tensor = tf.convert_to_tensor(reader.get_tensor(name))

      if tensor.dtype == tf.string:
        var_values[name] = tensor
      else:
        var_values[name] = tf.cast(var_values[name], tensor.dtype)
        var_values[name] += tensor
      var_dtypes[name] = tensor.dtype
    logger.info("Read from checkpoint %s", checkpoint)

  for name in var_values:  # Average.
    if var_dtypes[name] != tf.string:
      var_values[name] /= len(checkpoints)

冷你解释如何将平均值保存var_values到一个检查点?

4

1 回答 1

0

我可以通过参考同一问题的 Keras 版本来保存平均检查点,因为 Tensorflow 2.1 遵循 Keras API。

URL:keras 模型中的平均权重

于 2020-03-27T12:06:25.393 回答