1

我觉得我真的不知道我在做什么,所以我会描述我认为我在做什么,我想做什么以及失败的地方。

给定一个正常的变分自编码器:

...
net = tf.layers.dense(net, units=code_size * 2, activation=None)
mean = net[:, :code_size]
std = net[:, code_size:]
posterior = tfd.MultivariateNormalDiagWithSoftplusScale(mean, std)
net = posterior.sample()
net = tf.layers.dense(net, units=input_size, ...)
...

我想我在做什么:让神经网络找到一个“均值”和“标准差”值,并用它来创建正态分布(高斯)。从该分布中采样并将其用于解码器。换句话说:学习编码的高斯分布

现在我想对混合高斯做同样的事情。

...
net = tf.layers.dense(net, units=code_size * 2 * code_size, activation=None)

means, stds = tf.split(net, 2, axis=-1)

means = tf.split(means, code_size, axis=-1)
stds = tf.split(stds, code_size, axis=-1)

components = [tfd.MultivariateNormalDiagWithSoftplusScale(means[i], stds[i]) for i in range(code_size)]
probs = [1.0 / code_size] * code_size

gauss_mix = tfd.Mixture(cat=tfd.Categorical(probs=probs), components=components)
net = gauss_mix.sample()
net = tf.layers.dense(net, units=input_size, ...)
...

这对我来说似乎相对简单,只是它失败并出现以下错误:

形状 () 和 (?,) 不兼容

这似乎来自probs没有批量维度(我不认为它需要那个)。

我认为这probs定义了组件之间的概率。

如果我定义一个probs也有批次维度的,我会得到以下神秘错误,我不知道它应该是什么意思:

维度 -1796453376 必须 >= 0

我通常会误解一些概念吗?

或者我需要做些什么不同的事情?

4

0 回答 0