是否有约定在 PyTorch Lightning 中实现某种predict()
方法,在使用 执行实际预测之前进行预处理forward()
?
在我的例子中,我有一个由嵌入层和几个全连接层组成的文本分类器。文本在传递到嵌入层之前需要进行标记。在训练和评估期间,LightningDataModule
'setup()
方法可以完成工作。
现在,我想知道在生产过程中进行推理的最佳实践是什么。我可以在其中添加一个predict()
方法,LightningModule
我可以在其中编写与LightningDataModule.setup()
. 但是,当然,我不想复制代码。
在官方 PyTorch Lightning 文档中链接的这个社区示例prepare_sample()
项目中,作者定义了一个由他们的函数LightningModule
使用的predict()
函数,并且也传递给LightningDataModule
.
这是处理预处理的正确方法吗?另外,为什么没有prepare_sample()
或predict()
在LightningModule
?对我来说,这似乎是一个常见的用例,例如:
model = load_model('data/model.ckpt') # load pre-trained model, analyzes user reviews
user_input = input('Your movie review > ')
predicted_rating = model.predict(user_input) # e.g. "I liked the movie pretty much." -> 4 stars
print('Predicted rating: %s/5 stars' % predicted_rating)
现在我考虑了一下,predict()
也应该forward()
以与评估代码相同的方式处理结果,例如选择具有最高输出的类或选择输出大于某个阈值的所有类 - 一些不应重复的代码。