我一直在尝试(粗略地)gpflow
SVGP
在玩具数据集上训练和保存模型,很大程度上遵循这个笔记本示例
保存模型后使用pickle
(我很欣赏不推荐这样做,但我不认为这是这里的主要问题),我发现了一些不寻常的,我认为是意外行为:如果我们不打电话gpflow.utilities.freeze(model)
,在尝试之前pickle
model
,那么我们得到一个错误。如果我们调用gpflow.utilities.freeze(model)
(丢弃返回的冻结模型),则model
可以正确腌制。
重现
最小的、可重现的例子
import numpy as np
import gpflow
import tensorflow as tf
import pickle
rng = np.random.RandomState(123)
N = 10000 # Number of training observations
X = rng.rand(N, 1)
Y = rng.randn(N, 1)
data = (X, Y)
n_inducing_vars = 100
Z = X[:n_inducing_vars]
minibatch_size = 100
n_iterations = 100
#Define model object
model = gpflow.models.SVGP(gpflow.kernels.Matern12(), gpflow.likelihoods.Bernoulli(), inducing_variable=Z, num_data=N)
#Create minibatch object
data_minibatch = (
tf.data.Dataset.from_tensor_slices(data).prefetch(
N).repeat().shuffle(N).batch(minibatch_size)
)
data_minibatch_it = iter(data_minibatch)
model_objective = model.training_loss_closure(data_minibatch_it)
#Define optimiser
optimizer = tf.keras.optimizers.Adam(0.001)
#Optimise both variational parameters and kernel hyperparameters.
for step in range(n_iterations):
optimizer.minimize(model_objective,
var_list=model.trainable_variables
)
freeze = False
if not freeze:
# pickle doesn't work
pickle.dump(model, open('test1', 'wb'))
else:
# if following code is executed, pickle works fine
_ = gpflow.utilities.freeze(model) # ignore return value
pickle.dump(model, open('test1', 'wb'))
堆栈跟踪或错误消息
TypeError Traceback (most recent call last)
<ipython-input-6-3d5f537ca994> in <module>
----> 1 pickle.dump(model, open('test1', 'wb'))
TypeError: can't pickle HashableWeakRef objects
预期行为
并不是说我希望泡菜在第一个实例中起作用,因为我知道这不是tensorflow
一般保存相关对象的推荐方法。但是,我当然不希望它在第一次失败但在第二次成功。从代码库来看,我不认为gpflow.utilities.freeze(model)
应该是 mutating model
,它似乎正在这样做。
系统信息
- 使用 GPflow 版本 2.0.0 ... 2.0.4 进行测试
- TensorFlow 版本:2.1.0,tensorflow_probability 0.9.0
- Python版本:Python 3.6.9
我猜想调用freeze
它model
实际上是莫名其妙地转换model
为“冻结”模型,然后具有“常量”属性(https://gpflow.readthedocs.io/en/master/notebooks/intro_to_gpflow2.html#TensorFlow- saved_model ) 使其能够被腌制。
对此事的任何澄清将不胜感激。
请注意,我issue
在gpflow
github
( https://github.com/GPflow/GPflow/issues/1493 ) 上发布了此问题,但决定应在此处向更广泛的 gpflow 社区广播此问题。