1

我一直在尝试(粗略地)gpflow SVGP在玩具数据集上训练和保存模型,很大程度上遵循这个笔记本示例

保存模型后使用pickle(我很欣赏不推荐这样做,但我不认为这是这里的主要问题),我发现了一些不寻常的,我认为是意外行为:如果我们不打电话gpflow.utilities.freeze(model),在尝试之前pickle model,那么我们得到一个错误。如果我们调用gpflow.utilities.freeze(model)(丢弃返回的冻结模型),则model可以正确腌制。

重现

最小的、可重现的例子

import numpy as np
import gpflow
import tensorflow as tf
import pickle
rng = np.random.RandomState(123)

N = 10000  # Number of training observations
X = rng.rand(N, 1)
Y = rng.randn(N, 1)
data = (X, Y)

n_inducing_vars = 100
Z = X[:n_inducing_vars]
minibatch_size = 100
n_iterations = 100

#Define model object
model = gpflow.models.SVGP(gpflow.kernels.Matern12(), gpflow.likelihoods.Bernoulli(), inducing_variable=Z, num_data=N)

#Create minibatch object
data_minibatch = (
tf.data.Dataset.from_tensor_slices(data).prefetch(
    N).repeat().shuffle(N).batch(minibatch_size)
    )
data_minibatch_it = iter(data_minibatch)
model_objective = model.training_loss_closure(data_minibatch_it)

#Define optimiser
optimizer = tf.keras.optimizers.Adam(0.001)
#Optimise both variational parameters and kernel hyperparameters.
for step in range(n_iterations):
    optimizer.minimize(model_objective,
                       var_list=model.trainable_variables
                       )

freeze = False
if not freeze:
    # pickle doesn't work
    pickle.dump(model, open('test1', 'wb'))
else:
    # if following code is executed, pickle works fine
    _ = gpflow.utilities.freeze(model)  # ignore return value
    pickle.dump(model, open('test1', 'wb'))

堆栈跟踪或错误消息

TypeError                                 Traceback (most recent call last)
<ipython-input-6-3d5f537ca994> in <module>
----> 1 pickle.dump(model, open('test1', 'wb'))

TypeError: can't pickle HashableWeakRef objects

预期行为

并不是说我希望泡菜在第一个实例中起作用,因为我知道这不是tensorflow一般保存相关对象的推荐方法。但是,我当然不希望它在第一次失败但在第二次成功。从代码库来看,我不认为gpflow.utilities.freeze(model)应该是 mutating model,它似乎正在这样做。

系统信息

  • 使用 GPflow 版本 2.0.0 ... 2.0.4 进行测试
  • TensorFlow 版本:2.1.0,tensorflow_probability 0.9.0
  • Python版本:Python 3.6.9

我猜想调用freezemodel实际上是莫名其妙地转换model为“冻结”模型,然后具有“常量”属性(https://gpflow.readthedocs.io/en/master/notebooks/intro_to_gpflow2.html#TensorFlow- saved_model ) 使其能够被腌制。

对此事的任何澄清将不胜感激。

请注意,我issuegpflow github( https://github.com/GPflow/GPflow/issues/1493 ) 上发布了此问题,但决定应在此处向更广泛的 gpflow 社区广播此问题。

4

2 回答 2

1

此行为适用于使用 tensorflow_probability 的双射器的任何代码/模型,并且不限于 SVGP 模型。在 GPflow 中,双射器用于约束参数,例如确保核方差和长度尺度始终为正。

基本解释是 tensorflow_probability 的双射器保留了他们操作过的张量的缓存,例如,这允许他们在以下示例中准确地恢复原始张量:

import tensorflow as tf
import tensorflow_probability as tfp
bij = tfp.bijectors.Exp()
x = tf.constant(1.2345)
y = bij.forward(x)
assert bij.inverse(y) is x  # actual object identity, not just numerical equivalence

然而,这些缓存使用无法腌制的 HashableWeakRef 对象 - 甚至无法复制(使用 Python stdlib 的copy.deepcopy函数)。

只有当你通过双射器实际运行张量时,缓存才会被填充——如果你只是创建模型而不优化它,你可以腌制(或复制)它就好了。但当然,这通常不是很有用。

为了解决这个问题并允许复制“使用过的”(例如经过训练的)模型,我们有gpflow.utilities.reset_cache_bijectors(). 这被调用gpflow.utilities.deepcopy()以允许复制。反过来需要进行深度复制以便为您提供冻结副本,而gpflow.utilities.freeze()不是原地冻结模型,这解释了轻微的副作用。

因此,这不是freeze使您能够成功腌制它所必需的;在酸洗之前添加调用就足够了reset_cache_bijectors(model),将示例中的代码替换为

if not freeze:
    gpflow.utilities.reset_cache_bijectors(model)  # with this added call, pickle *does* work
    pickle.dump(model, open('test1', 'wb'))

最终,这是一个只能在上游通过 tensorflow_probability 在他们自己的代码中“正确”修复的问题。更多详细信息可以在 awav 对 tensorflow_probability 的这个请求中找到,该请求旨在解决这个问题。

附带说明一下,正如markvdw所指出的,您可能会发现存储使用获得的模型的所有参数值gpflow.utilities.read_values()(它返回值的参数键的字典)更容易,您可以以任何您喜欢的方式存储它,然后重新- 通过首先重新创建对象然后使用gpflow.utilities.multiple_assign().

于 2020-06-03T14:32:22.063 回答
1

让我们看看下面几行发生了什么:

if not freeze:
    # pickle doesn't work
    pickle.dump(model, open('test1', 'wb'))  # Line 1
else:
    # if following code is executed, pickle works fine
    gpflow.utilities.freeze(model)           # Line 2
    pickle.dump(model, open('test1', 'wb'))  # Line 3

Line 1经过训练的模型中,包含Parameters将 TensorFlow 概率双射器作为从受约束空间到不受约束空间并返回的变换器的实例。TFP 双射器缓存所有正向和反向计算。双射器的缓存是用映射实现的,其中键是正向和反向函数的张量输入,值是返回的张量对象。不幸的是,张量(例如 np.arrays)不能被散列,为此 TFP 实现了HashableWeakRef张量的包装器。错误消息“TypeError: can't pickle HashableWeakRef objects”具有误导性。这实际上意味着HashableWeakRefpython 无法复制该实例,仅仅是因为它是对尚未创建的对象的引用。因此,这些对象不能被腌制。

Line 2中,您有freeze由两个调用组成的方法:第一个调用删除双射器的内容,即缓存,第二个调用是copy.deepcopy。背后的魔力freeze在于它删除了引用。是的,它修改了现有对象,但它既不影响急切计算也不影响tf.functioned 函数。清洁使之成为deepcopy可能。

Line 3有效,因为该对象没有对复制的引用。

此问题有很长的报告记录,并尝试在 GPflow 中修复它: TFP#547TFP#944GPflow#1479GPflow#1293GPflow#1338

这是 TensorFlow 概率中的建议修复:TFP#947

于 2020-06-03T15:09:17.380 回答