我正在 TensorFlow 中开发一个 VAE 项目,其中编码器/解码器网络内置于函数中。这个想法是能够保存,然后加载训练好的模型并使用编码器功能进行采样。
恢复模型后,我无法让解码器函数运行并将恢复的、经过训练的变量返回给我,出现“未初始化值”错误。我认为这是因为该功能要么创建一个新的,要么覆盖现有的,或者以其他方式。但我不知道如何解决这个问题。这是一些代码:
class VAE(object):
def __init__(self, restore=True):
self.session = tf.Session()
if restore:
self.restore_model()
self.build_decoder = tf.make_template('decoder', self._build_decoder)
@staticmethod
def _build_decoder(z, output_size=768, hidden_size=200,
hidden_activation=tf.nn.elu, output_activation=tf.nn.sigmoid):
x = tf.layers.dense(z, hidden_size, activation=hidden_activation)
x = tf.layers.dense(x, hidden_size, activation=hidden_activation)
logits = tf.layers.dense(x, output_size, activation=output_activation)
return distributions.Independent(distributions.Bernoulli(logits), 2)
def sample_decoder(self, n_samples):
prior = self.build_prior(self.latent_dim)
samples = self.build_decoder(prior.sample(n_samples), self.input_size).mean()
return self.session.run([samples])
def restore_model(self):
print("Restoring")
self.saver = tf.train.import_meta_graph(os.path.join(self.save_dir, "turbolearn.meta"))
self.saver.restore(self.sess, tf.train.latest_checkpoint(self.save_dir))
self._restored = True
想跑samples = vae.sample_decoder(5)
在我的训练程序中,我运行:
if self.checkpoint:
self.saver.save(self.session, os.path.join(self.save_dir, "myvae"), write_meta_graph=True)
更新
根据以下建议的答案,我更改了还原方法
self.saver = tf.train.Saver()
self.saver.restore(self.session, tf.train.latest_checkpoint(self.save_dir))
但是现在在创建 Saver() 对象时得到一个值错误:
ValueError: No variables to save