我有一个在 Google Cloud Run 上运行的 Flask 应用程序,它需要下载一个大型模型(来自 huggingface 的 GPT-2)。这需要一段时间来下载,所以我正在尝试设置它只在部署时下载,然后只为后续访问提供服务。也就是说,我的主烧瓶应用 app.py 导入的脚本中有以下代码:
import torch
# from transformers import GPT2Tokenizer, GPT2LMHeadModel
from transformers import AutoTokenizer, AutoModelWithLMHead
# Disable gradient calculation - Useful for inference
torch.set_grad_enabled(False)
# Check if gpu or cpu
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Load tokenizer and model
try:
tokenizer = AutoTokenizer.from_pretrained("./gpt2-xl")
model = AutoModelWithLMHead.from_pretrained("./gpt2-xl")
except Exception as e:
print('no model found! Downloading....')
AutoTokenizer.from_pretrained('gpt2').save_pretrained('./gpt2-xl')
AutoModelWithLMHead.from_pretrained('gpt2').save_pretrained('./gpt2-xl')
tokenizer = AutoTokenizer.from_pretrained("./gpt2-xl")
model = AutoModelWithLMHead.from_pretrained("./gpt2-xl")
model = model.to(device)
这基本上会尝试加载下载的模型,如果失败,它会下载模型的新副本。我将自动缩放设置为至少 1,我认为这意味着某些东西会一直在运行,因此即使在活动之后下载的文件也会持续存在。但是当有些人尝试使用它时,它必须重新下载模型,这会冻结应用程序。我正在尝试重新创建类似这个应用程序https://text-generator-gpt2-app-6q7gvhilqq-lz.a.run.app/的东西,它似乎没有相同的加载时间问题。在烧瓶应用程序本身中,我有以下内容:
@app.route('/')
@cross_origin()
def index():
prompt = wp[random.randint(0, len(wp)-1)]
res = generate(prompt, size=75)
generated = res.split(prompt)[-1] + '\n \n...TO BE CONTINUED'
#generated = prompt
return flask.render_template('main.html', prompt = prompt, output = generated)
if __name__ == "__main__":
app.run(host='0.0.0.0',
debug=True,
port=PORT)
但它似乎每隔几个小时就重新下载一次模型......我怎样才能避免让应用程序重新下载模型和那些想尝试它的人冻结应用程序?