我有n
网络,每个网络都有相同的输入/输出。我想根据分类分布随机选择一个输出。Tfp.Categorical只输出整数,我试图做类似的事情
act_dist = tfp.distributions.Categorical(logits=act_logits) # act_logits are all the same, so the distribution is uniform
rand_out = act_dist.sample()
x = nn_out1 * tf.cast(rand_out == 0., dtype=tf.float32) + ... # for all my n networks
但是rand_out == 0.
总是假的,以及其他条件。
有什么想法可以实现我的需要吗?