如果输入的标签超出范围,有没有办法避免tfp.distributions.Categorical.log_prob
引发错误?
我将一批样本传递给该log_prob
方法,其中一些样本具有 value n_categories + 1
,当您从全零的概率分布中取样时,您会得到作为后备值的值。probs
我的批次中的一些概率分布全为零**。
dec_output, h_state, c_state = self.decoder(dec_inp, [h_state, c_state])
probs = self.attention(enc_output, dec_output, pointer_mask, len_mask)
distr = tfp.distributions.Categorical(probs=probs)
pointer = distr.sample()
log_prob = distr.log_prob(pointer) # log of the probability of choosing that action
我不在乎log_prob
在这些情况下我能从中获得什么价值,因为稍后我将掩盖它而不使用它。不确定一个fallback
值是否可以以某种方式实现。如果没有,是否有任何解决方法可以避免在我以图形模式(使用 )执行时引发错误@tf.function
?
**这是因为我正在使用 RNN 进行随机解码,该 RNN 是一批可变长度序列,一个 seq to seq 任务。