1

我正在尝试对具有相同结构但使用不同数据集进行训练的两个模型的张量进行平均。模型存储在 ckpt 文件中。

我试图从 tensor2tensor 查看avg_checkpoints 函数,但不知道如何使用它。

我该如何解决这个问题?

from tensor2tensor.utils import avg_checkpoints

print(avg_checkpoints.checkpoint_exists("/"))
#I got true from console
#I have copied final ckpt from different model to the root file

avg_checkpoint.main(?)
#no idea what to replace the ? with
4

1 回答 1

2

avg_checkpoints.py是一个可执行脚本,因此您可以从命令行使用它,例如:

python utils/avg_checkpoints.py
  --checkpoints path/to/checkpoint1,path/to/checkpoint2
  --num_last_checkpoints 2
  --output_path where/to/save/the/output

请注意,如果这两个检查点是从头开始在不同的数据集上训练的,则平均将不起作用。如果您有一个预训练模型,您只需在两个不同的数据集上进行微调,那么平均可以工作。

您可以平均两个以上的检查点。为每个检查点添加权重的一种简单但简单的方法是将其包含多次--checkpoints(并相应地增加num_last_checkpoints)。

于 2019-07-30T00:54:26.073 回答