3

我有一个张量流检查点,在使用常规例程重新定义与它对应的图形后,我可以加载该检查点tf.train.Saver()saver.restore(session, 'my_checkpoint.ckpt').

但是,现在,我想修改网络的第一层以接受形状的输入 say[200, 200, 1]而不是[200, 200, 10].

为此,我想将第一层对应的张量的形状从[3, 3, 10, 32](3x3 内核,10 个输入通道,32 个输出通道)修改为[3, 3, 1, 32]通过第三维求和。

我怎么能那样做?

4

2 回答 2

1

我找到了一种方法来做到这一点,但不是那么简单。给定一个检查点,我们可以将其转换为序列化的 numpy 数组(或我们可能发现适合保存 numpy 数组字典的任何其他格式),如下所示:

checkpoint = {}
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.restore(sess, 'my_checkpoint.ckpt')
    for x in tf.global_variables():
        checkpoint[x.name] = x.eval()
    np.save('checkpoint.npy', checkpoint)

可能有一些异常需要处理,但让我们保持代码简单。

然后,我们可以对 numpy 数组执行任何我们喜欢的操作:

checkpoint = np.load('checkpoint.npy')
checkpoint = ...
np.save('checkpoint.npy', checkpoint)

最后,我们可以在构建图表后手动加载权重,如下所示:

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    checkpoint = np.load('checkpoint.npy').item()
    for key, data in checkpoint.iteritems():
        var_scope = ... # to be extracted from key
        var_name = ...  # 
        with tf.variable_scope(var_scope, reuse=True):
            var = tf.get_variable(var_name)
            sess.run(var.assign(data))

如果有更直接的方法,我会全力以赴!

于 2018-01-07T20:56:50.780 回答
1

您可以使用 tensorflow::BundleReader 读取源 ckpt,并使用 tensorflow::BundleWriter 对其进行重写。

tensorflow::BundleReader reader(Env::Default(), model_path_prefix);
std::vector<std::string> tensor_names;
reader.Seek("");
reader.Next();
for (; reader.Valid(); reader.Next()) {
    tensor_names.emplace_back(reader.key());
}
tensorflow::BundleWriter writer(Env::Default(), new_model_path_prefix);   
for (auto &tensor_name : tensor_names) {
        DataType dtype;
        TensorShape shape;        
        
        reader.LookupDtypeAndShape(tensor_name, &dtype, &shape);
        Tensor val(dtype, shape);
        Status bool_ret  = reader.Lookup(tensor_name, &val);
        std::cout << tensor_name << " " << val.DebugString() << std::endl;
        // modify dtype and shape. padding Tensor
        TensorSlice slice(new_shape.dims());
        writer.AddSlice(tensor_name, new_shape, slice, new_val);
    }
}
writer.Finish();
于 2021-04-07T02:17:35.747 回答