我目前正在研究顺序设置中变分自动编码器的变体,其中任务是拟合/恢复一系列实值观察数据(因此这是一个回归问题)。
我已经使用tf.keras
启用了急切执行和 tensorflow_probability (tfp) 来构建我的模型。遵循 VAE 概念,生成网络发出观察数据的分布参数,我将其建模为多元正态。因此,输出是预测分布的均值和 logvar。
关于训练过程,损失的第一部分是重建误差。这是真实观察的对数似然,给定来自生成网络的预测(参数)分布。在这里,我使用tfp.distributions
,因为它既快速又方便。
然而,在训练完成后,以相当低的损失值为标志,事实证明我的模型似乎没有学到任何东西。模型的预测值在时间维度上几乎没有变化(回想一下,问题是连续的)。
尽管如此,为了进行完整性检查,当我用 MSE 损失替换对数似然度(在 VAE 上工作时这是不合理的)时,它会产生非常好的数据拟合。所以我得出结论,这个对数似然项一定有问题。有没有人对此有一些线索和/或解决方案?
我曾考虑用交叉熵损失替换对数似然,但我认为这不适用于我的情况,因为我的问题是回归并且数据无法归一化为 [0,1] 范围。
当使用对数似然作为重建损失时,我还尝试实现退火 KL 项(即用常数 < 1 加权 KL 项)。但它也没有奏效。
这是我的原始损失函数的代码片段(使用对数似然作为重建错误):
import tensorflow as tf
tfe = tf.contrib.eager
tf.enable_eager_execution()
import tensorflow_probability as tfp
tfd = tfp.distributions
def loss(model, inputs):
outputs, _ = SSM_model(model, inputs)
#allocate the corresponding output component
infer_mean = outputs[:,:,:latent_dim] #mean of latent variable from inference net
infer_logvar = outputs[:,:,latent_dim : (2 * latent_dim)]
trans_mean = outputs[:,:,(2 * latent_dim):(3 * latent_dim)] #mean of latent variable from transition net
trans_logvar = outputs[:,:, (3 * latent_dim):(4 * latent_dim)]
obs_mean = outputs[:,:,(4 * latent_dim):((4 * latent_dim) + output_obs_dim)] #mean of observation from generative net
obs_logvar = outputs[:,:,((4 * latent_dim) + output_obs_dim):]
target = inputs[:,:,2:4]
#transform logvar to std
infer_std = tf.sqrt(tf.exp(infer_logvar))
trans_std = tf.sqrt(tf.exp(trans_logvar))
obs_std = tf.sqrt(tf.exp(obs_logvar))
#computing loss at each time step
time_step_loss = []
for i in range(tf.shape(outputs)[0].numpy()):
#distribution of each module
infer_dist = tfd.MultivariateNormalDiag(infer_mean[i],infer_std[i])
trans_dist = tfd.MultivariateNormalDiag(trans_mean[i],trans_std[i])
obs_dist = tfd.MultivariateNormalDiag(obs_mean[i],obs_std[i])
#log likelihood of observation
likelihood = obs_dist.prob(target[i]) #shape = 1D = batch_size
likelihood = tf.clip_by_value(likelihood, 1e-37, 1)
log_likelihood = tf.log(likelihood)
#KL of (q|p)
kl = tfd.kl_divergence(infer_dist, trans_dist) #shape = batch_size
#the loss
loss = - log_likelihood + kl
time_step_loss.append(loss)
time_step_loss = tf.convert_to_tensor(time_step_loss)
overall_loss = tf.reduce_sum(time_step_loss)
overall_loss = tf.cast(overall_loss, dtype='float32')
return overall_loss