0

我正在实现一个程序来从分类分布中采样整数,其中每个整数都与一个概率相关联。我需要确保这个程序是可微的,以便可以应用反向传播。我发现tf.contrib.distributions.RelaxedOneHotCategorical这非常接近我想要实现的目标。

然而,sample这个类的方法返回一个单热向量,而不是一个整数。如何编写一个既可微又返回整数/标量而不是向量的程序?

4

3 回答 3

0

您无法以可微分的方式获得您想要的东西,因为argmax不可微分,这就是首先创建 Gumbel-Softmax 分布的原因。例如,这允许您使用语言模型的输出作为生成对抗网络中鉴别器的输入,因为随着温度的变化,激活接近一个单热向量。

如果您只需要在推理或测试时检索最大元素,您可以使用tf.math.argmax. 但是没有办法以可区分的方式做到这一点。

于 2021-05-29T14:01:33.770 回答
0

您可以对松弛的一个热向量与 的向量进行点积[1 2 3 4 ... n]。结果将为您提供所需的标量。

例如,如果您的一个热向量是[0 0 0 1],那么dot([0 0 0 1],[1 2 3 4])将给您 4 这就是您要寻找的。

于 2019-06-03T16:14:03.710 回答
0

RelaxedOneHotCategorical 实际上可微的原因与它返回浮点数的 softmax 向量而不是 argmax int 索引的事实有关。如果你想要的只是最大元素的索引,你不妨使用分类。

于 2019-03-07T21:01:02.053 回答