在检查了一个检查点(我们称之为模型 1)之后,我获得了下面的变量名列表(为简单起见缩短了):
var_list = ["ex1_model/fc2/b",
"ex1_model/fc2/b/Adam",
"ex1_model/fc2/b/Adam_1",
"ex1_model/fc2/w",
"ex1_model/fc2/w/Adam"]
假设我有一个更大的模型 2,并想用模型 1 的值初始化它的一部分。
从这里描述的名称中获取变量(因为我没有找到一种简单的方法):
def get_vars_by_name(names):
return [v for v in tf.global_variables() if v.name in names]
构建模型 2 和保护程序以恢复:
logits = build_model(inputs)
saver = tf.train.Saver(var_list=get_vars_by_name(var_list))
在
saver.restore(sess, tf.train.latest_checkpoint(checkpoints_dir))
我收到错误:
"ex1_model/fc2/w/Adam" [...] raise ValueError("No variables to save")
请帮我找出我犯的错误。我也很感激一种更简单的方法,因为这很糟糕。谢谢你。