我对 huggingface 的 distillBERT 工作很感兴趣,通过查看他们的代码 ( https://github.com/huggingface/transformers/blob/master/examples/distillation/train.py ),我发现如果使用 roBERTa 作为学生模型,他们会冻结位置嵌入,我想知道这是为了什么?
def freeze_pos_embeddings(student, args):
if args.student_type == "roberta":
student.roberta.embeddings.position_embeddings.weight.requires_grad = False
elif args.student_type == "gpt2":
student.transformer.wpe.weight.requires_grad = False
我理解冻结 token_type_embeddings 的原因,因为 roBERTa 从未使用过段嵌入,但为什么要定位嵌入呢?
非常感谢您的帮助!