在指针网络中,输出 logits 超过输入的长度。使用此类批次意味着将输入填充到批次输入的最大长度。现在,这一切都很好,直到我们必须计算损失。目前我正在做的是:
logits = stabilize(logits(inputs)) #[batch, max_length]. subtract max(logits) to stabilize
masks = masks(inputs) #[batch, max_length]. 1 for actual inputs, 0 for padded locations
exp_logits = exp(logits)
exp_logits_masked = exp_logits*masks
probs = exp_logits_masked/sum(exp_logits_masked)
现在我使用这些概率来计算交叉熵
cross_entropy = sum_over_batches(probs[correct_class])
我能做得比这更好吗?关于处理指针网络的人通常如何完成它的任何想法?
如果我没有可变大小的输入,这一切都可以使用tf.nn.softmax_cross_entropy_with_logits
logits 和标签上的可调用对象(高度优化)来实现,但是可变长度会产生错误的结果,因为对于输入中的每个填充,softmax 计算的分母大 1。