2

Huggigface BERT 实现有一个技巧,可以从优化器中删除池化器。

https://github.com/huggingface/transformers/blob/b832d5bb8a6dfc5965015b828e577677eace601e/examples/run_squad.py#L927

# hack to remove pooler, which is not used
# thus it produce None grad that break apex
param_optimizer = [n for n in param_optimizer if 'pooler' not in n[0]]

我们正在尝试在 huggingface bert 模型上运行 pretrining。如果不应用此 pooler hack,则代码在训练后期总是会发散。我还看到在分类过程中使用了池化层。

pooled_output = outputs[1]
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)

池化层是一个带有 tanh 激活的 FFN

class BertPooler(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output

我的问题是为什么这个 pooler hack 解决了数字不稳定性?

池化器出现的问题

递减损失定标器

4

1 回答 1

-1

有相当多的资源可能比我更好地解决这个问题,例如参见此处此处

具体来说,问题在于您正在处理消失(或爆炸)梯度,特别是当使用在非常小/大输入的任一方向上变平的损失函数时,sigmoid 和 tanh 都是这种情况(这里唯一的区别是它们的输出所在的范围,分别是[0, 1][-1, 1]

此外,如果你有一个低精度的小数,就像 APEX 的情况一样,那么梯度消失行为很可能已经出现在相对中等的输出中,因为精度限制了它能够从零区分的数字。解决这个问题的一种方法是使用具有严格非零且易于计算的导数的函数,例如 Leaky ReLU,或者干脆完全避免激活函数(我假设这是 huggingface 在这里所做的)。

请注意,梯度爆炸的问题通常没有那么悲惨,因为我们可以应用梯度裁剪(将其限制为固定的最大尺寸),但原理是相同的。另一方面,对于零梯度,没有这么简单的解决方法,因为它会导致你的神经元“死亡”(零回流没有发生主动学习),这就是为什么我假设你看到了发散行为。

于 2020-03-20T10:13:08.437 回答