The Huggingface BERT implementation has a hack that removes the pooler from the optimizer:
# hack to remove pooler, which is not used
# thus it produces None grads that break apex
param_optimizer = [n for n in param_optimizer if 'pooler' not in n[0]]
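The None-grad situation mentioned in that comment is easy to reproduce on its own. Below is a minimal sketch (my own toy model, not the pretraining script) showing that a submodule which never participates in the forward pass ends up with p.grad is None after backward(); optimizer or fp16 wrapper code that assumes every parameter has a gradient, as apex's fused optimizers did, can then fail.

import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(8, 8)
        self.pooler = nn.Linear(8, 8)   # registered but never called below

    def forward(self, x):
        # the pooler is skipped, analogous to MLM-only pretraining
        return self.encoder(x).sum()

model = Toy()
model(torch.randn(2, 8)).backward()
for name, p in model.named_parameters():
    print(name, p.grad is None)
# encoder.* -> False, pooler.* -> True: these None grads are what the hack avoids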
We are trying to run pretraining with the huggingface bert model. If this pooler hack is not applied, the run always diverges later in training. I also see that the pooling layer is used during classification:
pooled_output = outputs[1]
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
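For reference, that pooled_output is the second element of the tuple returned by the underlying BertModel, i.e. the pooler output for the [CLS] position. A small sketch of where it comes from (tiny made-up config; return_dict=False requests the legacy tuple output in recent transformers versions):

import torch
from transformers import BertConfig, BertModel

config = BertConfig(hidden_size=128, num_hidden_layers=2,
                    num_attention_heads=2, intermediate_size=256)
bert = BertModel(config)

input_ids = torch.randint(0, config.vocab_size, (2, 16))
outputs = bert(input_ids, return_dict=False)

sequence_output = outputs[0]  # (2, 16, 128): per-token hidden states
pooled_output = outputs[1]    # (2, 128): BertPooler applied to the first token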
The pooling layer is a feed-forward layer with a tanh activation:
import torch.nn as nn

class BertPooler(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output
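A quick usage sketch (the config class and shapes below are made up for illustration) showing that the pooler only reads the first token's hidden state and that tanh bounds its output to (-1, 1):

import torch

class DummyConfig:
    hidden_size = 128

pooler = BertPooler(DummyConfig())
hidden_states = torch.randn(4, 16, DummyConfig.hidden_size)  # (batch, seq_len, hidden)
pooled = pooler(hidden_states)

print(pooled.shape)                    # torch.Size([4, 128])
print(bool(pooled.abs().max() < 1.0))  # True: tanh keeps values in (-1, 1)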
My question is: why does this pooler hack resolve the numerical instability?
What exactly is the problem that arises from the pooler?