1

我对 huggingface 的 distillBERT 工作很感兴趣,通过查看他们的代码 ( https://github.com/huggingface/transformers/blob/master/examples/distillation/train.py ),我发现如果使用 roBERTa 作为学生模型,他们会冻结位置嵌入,我想知道这是为了什么?

def freeze_pos_embeddings(student, args):
    if args.student_type == "roberta":
        student.roberta.embeddings.position_embeddings.weight.requires_grad = False
    elif args.student_type == "gpt2":
        student.transformer.wpe.weight.requires_grad = False

我理解冻结 token_type_embeddings 的原因,因为 roBERTa 从未使用过段嵌入,但为什么要定位嵌入呢?
非常感谢您的帮助!

4

1 回答 1

2

在大多数(甚至所有)常用的 Transformer 中,位置嵌入没有经过训练,而是使用解析描述的函数定义(Attention 第 6 页上的未编号方程就是你需要的全部纸张):

在此处输入图像描述

为了节省Transformer 包中的计算时间,它们被预先计算到 512 的长度并存储为用作缓存的变量,在训练期间不应更改。

不训练位置嵌入的原因是后面位置的嵌入会训练不足,但是通过巧妙地分析定义位置嵌入,网络可以学习方程背后的规律性并更容易泛化更长的序列。

于 2020-03-23T08:05:47.040 回答