0

我正在尝试将 gpflow (2.0rc) 与 float64 一起使用,并且一直在努力让即使是简单的示例也能正常工作。我使用以下方式配置 gpflow:

gpflow.config.set_default_float(np.float64)

我正在使用 GPR:

# Model construction:
k = gpflow.kernels.Matern52(variance=1.0, lengthscale=0.3)
m = gpflow.models.GPR((X, Y), kernel=k)
m.likelihood.variance = 0.01

事实上,如果我打印一个摘要,两个参数都有 dtype float64。但是,如果我尝试使用此模型进行预测,则会出现错误。

tensorflow.python.framework.errors_impl.InvalidArgumentError: cannot compute AddV2 as input #1(zero-based) was expected to be a double tensor but is a float tensor [Op:AddV2] name: add/

调试会话将我带到 gpr.py 中的以下行(第 88 行)

s = tf.linalg.diag(tf.fill([num_data], self.likelihood.variance))

这将创建一个 dtype 为 float32 的矩阵,从而导致如上所述的爆炸。所以问题似乎与我设置可能性方差的方式有关?

这是演示该问题的完整 Python 脚本:

import numpy as np
import gpflow

gpflow.config.set_default_float(np.float64)

# data:
X = np.random.rand(10, 1)
Y = np.sin(X)

assert X.dtype == np.float64
assert Y.dtype == np.float64

# Model construction:
k = gpflow.kernels.Matern52(variance=1.0, lengthscale=0.3)
m = gpflow.models.GPR((X, Y), kernel=k)
m.likelihood.variance = 0.01

gpflow.utilities.print_summary(m)

# Predict
xx = np.array([[1.0]])
assert xx.dtype == np.float64

mean, var = m.predict_y(xx)
print(f'mean: {mean}')
print(f'var: {var}')
4

1 回答 1

0

我已经检查过了,显然可以在这里回答你自己的问题。

在进一步调查并在 github 上提出问题后,我想添加一些后续内容。

对我的问题的快速回答是,是的,问题在于我设置差异的方式。如果我这样设置:

m.likelihood.variance.assign(0.01)

那么我的代码就可以了。(尽管这并不完全等价,因为似然方差现在是一个可训练的参数。)

但这只是 Tensorflow 的 API 更广泛问题的一种表现 - 如此处提出的:https ://github.com/tensorflow/tensorflow/issues/26033

在此问题得到解决之前,尝试在任何基于 tensorflow 构建的库中使用 float64s 而不是默认的 float32 可能会很痛苦。

如果您尝试使用 float64s,它会导致对用户不利的 API 功能,因为在某些地方 Python 浮点参数很好,而在其他地方,它们需要转换为正确的浮点大小。

这是一个来自 tensorflow-probability 的示例,它演示了这个问题:

adaptive_hmc = tfp.mcmc.SimpleStepSizeAdaptation(
    hmc, 
    num_adaptation_steps=10, 
    target_accept_prob=0.75,
    adaptation_rate=0.1
)

比较 target_accept_prob 和 adaption_rate 参数。要真正让它工作,需要使用类似“tf.cast(0.75, dtype=tf.float64)”的东西来转换 target_accept_prob 参数,但无论使用 float32 还是 float64,adaption_rate 参数都很好。您需要深入研究源代码或等待异常并尝试找出哪个参数导致它知道需要此类转换的位置。

于 2020-02-04T16:42:26.367 回答