1

我最初正在尝试在烧瓶中加载 GPT2 微调模型。在初始化函数期间使用以下方法加载模型:

app.modelgpt2 = torch.load('models/model_gpt2.pt', map_location=torch.device('cpu'))
app.modelgpt2tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

但是在执行如下片段中的预测任务时:

from flask import current_app
input_ids = current_app.modelgpt2tokenizer.encode("sample sentence here", return_tensors='pt')
sample_outputs = current_app.modelgpt2.generate(input_ids,
                                                do_sample=True,
                                                top_k=50,
                                                min_length=30,
                                                max_length=300,
                                                top_p=0.95,
                                                temperature=0.7,
                                                num_return_sequences=1)

如问题中所述,它会引发以下错误: AttributeError: 'GPT2Model' object has no attribute 'gradient_checkpointing'

错误跟踪从model.generate函数开始列出:文件“/venv/lib/python3.8/site-packages/torch/autograd/grad_mode.py”,第 28 行,在 decorate_context return func(*args, **kwargs)

文件“/venv/lib/python3.8/site-packages/transformers/generation_utils.py”,第 1017 行,在生成返回 self.sample(

示例输出中的文件“/venv/lib/python3.8/site-packages/transformers/generation_utils.py”,第 1531 行 = self(

_call_impl 中的文件“/venv/lib/python3.8/site-packages/torch/nn/modules/module.py”,第 1102 行 return forward_call(*input, **kwargs)

文件“/venv/lib/python3.8/site-packages/transformers/models/gpt2/modeling_gpt2.py”,第 1044 行,在正向 transformer_outputs = self.transformer(

_call_impl 中的文件“/venv/lib/python3.8/site-packages/torch/nn/modules/module.py”,第 1102 行 return forward_call(*input, **kwargs)

文件“/venv/lib/python3.8/site-packages/transformers/models/gpt2/modeling_gpt2.py”,第 861 行,向前打印(self.gradient_checkpointing)

文件“/venv/lib/python3.8/site-packages/torch/nn/modules/module.py”,第 1177 行,在getattr raise AttributeError("'{}' object has no attribute '{}'".format (

AttributeError:“GPT2Model”对象没有属性“gradient_checkpointing”

选中modeling_gpt2.py,默认情况下在类的构造函数中self.gradient_checkpointing设置。False

4

1 回答 1

0

只有当框架使用 venv 或部署框架(如 uWSGI 或 gunicorn)运行时,才会发现此问题。当使用变形金刚 4.10.0 版而不是最新的包时,该问题已得到解决。

于 2021-11-20T11:21:43.510 回答