-2

我目前正在使用 GPR 实现一个带有 GPflow 的算法。我想在 GPR 训练后保存参数并加载模型进行测试。有人知道命令吗?

4

2 回答 2

3

GPflow现在有一个包含提示和技巧的页面。您可以点击链接找到您问题的答案。但是,我还要在这里粘贴 MWE:

假设你想存储 GPR 模型,你可以这样做gpflow.Saver()

kernel = gpflow.kernels.RBF(1)
x = np.random.randn(100, 1)
y = np.random.randn(100, 1)
model = gpflow.models.GPR(x, y, kernel)

filename = "/tmp/gpr.gpflow"
path = Path(filename)
if path.exists():
    path.unlink()
saver = gpflow.saver.Saver()
saver.save(filename, model)

要重新加载它,您必须使用以下任一解决方案:

with tf.Graph().as_default() as graph, tf.Session().as_default():
    model_copy = saver.load(filename)

或者,如果您想在之前存储模型的同一会话中加载模型,则需要应用一些技巧:

ctx_for_loading = gpflow.saver.SaverContext(autocompile=False)
model_copy = saver.load(filename, context=ctx_for_loading)
model_copy.clear()
model_copy.compile()

2020 年 6 月 1 日更新

GPflow 2.0 不提供自定义保护程序。它依赖于 TensorFlow 检查点和tf.saved_model. 您可以在此处找到示例: GPflow 介绍

于 2019-01-30T01:21:53.587 回答
0

我为 gpflow 模型采用的一种选择是仅保存和加载可训练数据。它假设您有一个构建和编译模型的函数。我在下面通过将变量保存到 hdf5 文件来展示这一点。

import h5py

def _load_model(model, load_file):
    """
    Load a model given by model path
    """

    vars = {}
    def _gather(name, obj):
        if isinstance(obj, h5py.Dataset):
            vars[name] = obj[...]

    with h5py.File(load_file) as f:
        f.visititems(_gather)

    model.assign(vars)

def _save_model(model, save_file):
    vars = model.read_trainables()
    with h5py.File(save_file) as f:
        for name, value in vars.items():
            f[name] = value
于 2019-01-30T11:10:33.147 回答