0

如果输入的标签超出范围,有没有办法避免tfp.distributions.Categorical.log_prob引发错误?

我将一批样本传递给该log_prob方法,其中一些样本具有 value n_categories + 1,当您从全零的概率分布中取样时,您会得到作为后备值的值。probs我的批次中的一些概率分布全为零**。

dec_output, h_state, c_state = self.decoder(dec_inp, [h_state, c_state])
probs = self.attention(enc_output, dec_output, pointer_mask, len_mask)
distr = tfp.distributions.Categorical(probs=probs)
pointer = distr.sample()
log_prob = distr.log_prob(pointer) # log of the probability of choosing that action

我不在乎log_prob在这些情况下我能从中获得什么价值,因为稍后我将掩盖它而不使用它。不确定一个fallback值是否可以以某种方式实现。如果没有,是否有任何解决方法可以避免在我以图形模式(使用 )执行时引发错误@tf.function

**这是因为我正在使用 RNN 进行随机解码,该 RNN 是一批可变长度序列,一个 seq to seq 任务。

4

1 回答 1

1

如果您可以屏蔽 log_prob,您也可以将概率屏蔽为 1 / n。请注意,使用 Categorical 的 logits 参数化并放弃(可能)上游 softmax 激活在数值上更稳定。

于 2021-05-26T14:40:06.053 回答