1

SVGPnatural gradient等 GPflow 的文档中,TensorFlow 中的 Adam 优化器用于使用随机变分推理技术训练 GP 模型的模型参数(长度尺度、方差、诱导输入等),而自然梯度优化器对于变分参数。一个片段如下所示

def run_adam(model, iterations):
    """
    Utility function running the Adam optimizer

    :param model: GPflow model
    :param interations: number of iterations
    """
    # Create an Adam Optimizer action
    logf = []
    train_iter = iter(train_dataset.batch(minibatch_size))
    training_loss = model.training_loss_closure(train_iter, compile=True)
    optimizer = tf.optimizers.Adam()

    @tf.function
    def optimization_step():
        optimizer.minimize(training_loss, model.trainable_variables)

    for step in range(iterations):
        optimization_step()
        if step % 10 == 0:
            elbo = -training_loss().numpy()
            logf.append(elbo)
    return logf

如图所示,model.trainable_variables 被传递给 Adam 优化器,该优化器继承自 tf.Module,由包括长度尺度和方差在内的多个参数组成。

我关心的是 Adam 优化器是否正在处理模型参数的无约束或约束版本。一段测试代码运行如下

import gpflow as gpf
import numpy as np

x = np.arange(10)[:, np.newaxis]
y = np.arange(10)[:, np.newaxis]
model = gpf.models.GPR((x, y), 
                       kernel = gpf.kernels.SquaredExponential(variance = 2, lengthscales = 3), 
                       noise_variance = 4)

model.kernel.parameters[0].unconstrained_variable is model.trainable_variables[0]

并返回

True

据我所知,高斯过程的参数(如长度尺度和内核的方差)不是负数,在训练时应该受到约束。我不是 GPflow 或 TensorFlow 的源代码专家,但似乎 Adam 正在研究不受约束的参数。这只是对我的误解,还是别的什么?

提前感谢您的帮助!

4

1 回答 1

3

你是对的,这是设计使然。GPflow 中的受约束变量由 表示ParameterParameter包裹unconstrained_variable. _ 当您访问.trainable_variables您的模型时,这将包括 的unconstrained_variableParameter因此当您将这些变量传递给优化器时,优化器将训练这些变量而不是Parameter自身。

但是您的模型看不到unconstrained_value,它看到的Parameter接口是与通过可选转换tf.Tensor包装的类似接口的接口。unconstrained_variable此转换将不受约束的值映射到受约束的值。因此,您的模型只会看到受约束的值。您的约束值必须为正值不是问题,转换会将未约束值的负值映射到约束值的正值。

您可以看到内核的第一个不受约束和受约束的值Parameter,以及与它们相关的变换

param = model.kernel.parameters[0]
param.value()  # this is what your model will see
param.unconstrained_variable  # this is what the optimizer will see
param.transform  # the above two are related via this
于 2020-05-25T10:33:14.643 回答