1

我有n网络,每个网络都有相同的输入/输出。我想根据分类分布随机选择一个输出。Tfp.Categorical只输出整数,我试图做类似的事情

act_dist = tfp.distributions.Categorical(logits=act_logits) # act_logits are all the same, so the distribution is uniform
rand_out = act_dist.sample()
x = nn_out1 * tf.cast(rand_out == 0., dtype=tf.float32) + ... # for all my n networks

但是rand_out == 0.总是假的,以及其他条件。

有什么想法可以实现我的需要吗?

4

2 回答 2

1

您还可以查看 MixtureSameFamily,它为您提供了一个秘密集合。

nn_out1 = tf.expand_dims(nn_out1, axis=2)
...
outs = tf.concat([nn_out1, nn_nout2, ...], axis=2)
probs = tf.tile(tf.reduce_mean(tf.ones_like(nn_out1), axis=1, keepdims=True) / n, [1, n]) # trick to have ones of shape [None,1]
dist = tfp.distributions.MixtureSameFamily(
        mixture_distribution=tfp.distributions.Categorical(probs=probs),
        components_distribution=tfp.distributions.Deterministic(loc=outs))
x = dist.sample()
于 2019-06-21T22:44:26.527 回答
0

我认为您需要使用 tf.equal,因为 Tensor == 0 始终为 False。

不过,您可能希望单独使用 OneHotCategorical。对于培训,您也可以尝试使用 RelaxedOneHotCategorical。

于 2019-06-21T13:15:23.127 回答