我正在尝试对具有相同结构但使用不同数据集进行训练的两个模型的张量进行平均。模型存储在 ckpt 文件中。
我试图从 tensor2tensor 查看avg_checkpoints 函数,但不知道如何使用它。
我该如何解决这个问题?
from tensor2tensor.utils import avg_checkpoints
print(avg_checkpoints.checkpoint_exists("/"))
#I got true from console
#I have copied final ckpt from different model to the root file
avg_checkpoint.main(?)
#no idea what to replace the ? with