我正在实现一个程序来从分类分布中采样整数,其中每个整数都与一个概率相关联。我需要确保这个程序是可微的,以便可以应用反向传播。我发现tf.contrib.distributions.RelaxedOneHotCategorical
这非常接近我想要实现的目标。
然而,sample
这个类的方法返回一个单热向量,而不是一个整数。如何编写一个既可微又返回整数/标量而不是向量的程序?
我正在实现一个程序来从分类分布中采样整数,其中每个整数都与一个概率相关联。我需要确保这个程序是可微的,以便可以应用反向传播。我发现tf.contrib.distributions.RelaxedOneHotCategorical
这非常接近我想要实现的目标。
然而,sample
这个类的方法返回一个单热向量,而不是一个整数。如何编写一个既可微又返回整数/标量而不是向量的程序?
您无法以可微分的方式获得您想要的东西,因为argmax
不可微分,这就是首先创建 Gumbel-Softmax 分布的原因。例如,这允许您使用语言模型的输出作为生成对抗网络中鉴别器的输入,因为随着温度的变化,激活接近一个单热向量。
如果您只需要在推理或测试时检索最大元素,您可以使用tf.math.argmax
. 但是没有办法以可区分的方式做到这一点。
您可以对松弛的一个热向量与 的向量进行点积[1 2 3 4 ... n]
。结果将为您提供所需的标量。
例如,如果您的一个热向量是[0 0 0 1]
,那么dot([0 0 0 1],[1 2 3 4])
将给您 4 这就是您要寻找的。
RelaxedOneHotCategorical 实际上可微的原因与它返回浮点数的 softmax 向量而不是 argmax int 索引的事实有关。如果你想要的只是最大元素的索引,你不妨使用分类。