1

我正在使用 gpflow 进行多输出回归。

我的回归目标是一个三维向量(相关的),我设法用完整的协方差矩阵进行了预测。这是我的实现。更具体地说,我在 tensorflow 之后使用 SVGP,其中 f_x, Y 是张量(我使用的是小批量训练)。在批量训练期间,小批量中的训练示例是相关的,而我们可以假设它们与训练集中的其他示例相互独立。因此,我想在训练期间在每个批次中实现相关内核,并使用完整的协方差矩阵进行预测。

kernel = mk.SharedIndependentMok(gpflow.kernels.RBF(args.feat_dim, ARD=False, name="rbf"), args.output_dim)

# kernel = mk.SeparateIndependentMok([gpflow.kernels.RBF(128, ARD=True, name="rbf_ard"+str(i)) for i in range(3)])
q_mu = np.zeros((args.batch_size, args.output_dim)).reshape(args.batch_size * args.output_dim, 1)
q_sqrt = np.eye(args.batch_size * args.output_dim).reshape(1, args.batch_size * args.output_dim, args.batch_size * args.output_dim)
# feature = gpflow.features.InducingPoints(np.zeros((args.batch_size, 128)))

self.gp_model = gpflow.models.SVGP(X=f_X, Y=Y, kern=kernel, likelihood=gpflow.likelihoods.Gaussian(name="lik"), Z=np.zeros((args.batch_size, args.feat_dim)), q_mu=q_mu, q_sqrt=q_sqrt, name="svgp")

我有两个问题:

  1. 我应该使用哪个内核?我在https://gpflow.readthedocs.io/en/develop/notebooks/multioutput.html#Shared-Independent-MOK-&-Shared-Independent-Features-(SLOW-CODE)中阅读了关于多输出 GP 的教程,谁能进一步解释以下声明?在我的理解中,不同输出维度之间的内核是不相关的,但是,我们仍然可以在预测中得到完整的协方差矩阵(条件)?

所有输出都是不相关的,并且每个输出使用相同的内核。但是,在条件计算期间,我们不假设这种特定的块对角结构。

  1. 一个技术问题。当我在 tensorflow 之后使用 gpflow 时,我应该如何设置“功能”。我应该设置一个带有零的numpy matix还是简单地设置为None?似乎在批量训练中,诱导点默认设置为全批量。虽然如何设置特征类型,例如 SharedIndependentMof,我们可以设置大于批量大小的诱导点吗?

谢谢!

4

1 回答 1

2

该笔记本描述了多输出框架及其计算方面。您突出显示的部分只是进行独立的 GP 回归,但目的是展示利用不同块对角结构的不同特征如何具有不同的计算属性。在小批量中纠正(相关?)训练示例并假设它们在小批量之间是独立的,很难在数学意义上准确理解您的意思。从推导变分下限的角度来看,这些是结果而不是假设。但是,根据您的描述,您似乎想要:

  • 使用小批量进行训练(即不需要考虑小批量之间的相关性 [Hensman et al 2013])
  • 提供输出之间相关性的内核。目前我们只支持SeparateMixedMok. Alvarez 等人对数学进行了很好的回顾。
  • 利用先验 ( MixedKernelSeparateMof) 中的独立属性的特征。
  • 具有完全协方差的预测。

本节对此进行了介绍3. Mixed Kernel & Uncorrelated features (OPTIMAL)。使用通常的预测功能时,只需确保设置full_output_cov=True.

小批量大小与诱导点的数量无关,尽管将诱导点的数量设置为大于训练数据的数量通常是没有用的。上面和笔记本中描述了要选择的功能。多输出的情况很复杂。

注意:您也可以通过 GitHub https://gpflow.readthedocs.io/en/master/notebooks/advanced/multioutput.html查看笔记本。

于 2018-12-29T13:03:42.683 回答