0

我有一个微调的简单变压器表示模型。现在我想仅以 pickle 格式保存池层的权重,并将其放在我正在设计的另一个自定义自动编码器的池层中。我如何使用 pytorch 和 python 来做到这一点?

4

1 回答 1

0

state_dict每个 PyTorch 模块旁边都有一个对象调用,它允许将任何参数映射到其对应的张量变量(更多信息请参见此处)。使用此实用程序,您可以轻松保存和加载参数,但请记住,您必须事先确定您想要在语义上(从机器学习的角度)和语法上(形状兼容性和......)做什么!pooling下面的实现将使用我们之前保存的模型中的相应变量替换名称中带有单词的任何参数。

finetuned_model = BertLMHeadModel.from_pretrained('bert-base-cased')
torch.save(finetuned_model.state_dict(), "finetuned_model.pth")
finetuned_model_state_dict = torch.load("finetuned_model.pth")
new_model = BertLMHeadModel.from_pretrained('bert-base-cased')
new_model_state_dict = new_model.state_dict()
for key, value in new_model_state_dict.items():
  if key.find('pooling')!=-1:
    new_model_state_dict.update({key: value})
于 2021-12-06T13:34:36.540 回答