2

在指针网络中,输出 logits 超过输入的长度。使用此类批次意味着将输入填充到批次输入的最大长度。现在,这一切都很好,直到我们必须计算损失。目前我正在做的是:

logits = stabilize(logits(inputs))     #[batch, max_length]. subtract max(logits) to stabilize
masks = masks(inputs)     #[batch, max_length]. 1 for actual inputs, 0 for padded locations
exp_logits = exp(logits)
exp_logits_masked = exp_logits*masks
probs = exp_logits_masked/sum(exp_logits_masked)

现在我使用这些概率来计算交叉熵

cross_entropy = sum_over_batches(probs[correct_class])

我能做得比这更好吗?关于处理指针网络的人通常如何完成它的任何想法?

如果我没有可变大小的输入,这一切都可以使用tf.nn.softmax_cross_entropy_with_logitslogits 和标签上的可调用对象(高度优化)来实现,但是可变长度会产生错误的结果,因为对于输入中的每个填充,softmax 计算的分母大 1。

4

1 回答 1

2

你的方法看起来很到位,据我所知,这也是在 RNN 单元中实现的方式。请注意,1x = dx 的导数和 0x = 0 的导数。这会产生您想要的结果,因为您正在对网络末端的梯度求和/平均。

您可能会考虑的唯一一件事是根据掩码值的数量重新调整损失。您可能会注意到,当有 0 个掩码值时,您的渐变的幅度与使用许多掩码值时的幅度略有不同。我不清楚这是否会产生重大影响,但可能会产生非常小的影响。

否则,我自己也使用同样的技术取得了巨大的成功,所以我在这里说你走在正确的轨道上。

于 2018-03-26T16:17:43.013 回答