我找到了一种方法来做到这一点,但不是那么简单。给定一个检查点,我们可以将其转换为序列化的 numpy 数组(或我们可能发现适合保存 numpy 数组字典的任何其他格式),如下所示:
checkpoint = {}
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
saver.restore(sess, 'my_checkpoint.ckpt')
for x in tf.global_variables():
checkpoint[x.name] = x.eval()
np.save('checkpoint.npy', checkpoint)
可能有一些异常需要处理,但让我们保持代码简单。
然后,我们可以对 numpy 数组执行任何我们喜欢的操作:
checkpoint = np.load('checkpoint.npy')
checkpoint = ...
np.save('checkpoint.npy', checkpoint)
最后,我们可以在构建图表后手动加载权重,如下所示:
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
checkpoint = np.load('checkpoint.npy').item()
for key, data in checkpoint.iteritems():
var_scope = ... # to be extracted from key
var_name = ... #
with tf.variable_scope(var_scope, reuse=True):
var = tf.get_variable(var_name)
sess.run(var.assign(data))
如果有更直接的方法,我会全力以赴!