由于 Pytorch Lightning 为模型检查点提供了自动保存功能,因此我使用它来保存 top-k 最佳模型。特别是在培训师设置中,
checkpoint_callback = ModelCheckpoint(
monitor='val_acc',
dirpath='checkpoints/',
filename='{epoch:02d}-{val_acc:.2f}',
save_top_k=5,
mode='max',
)
这运行良好,但它不保存模型对象的某些属性。我的模型在每个训练时期结束时都会存储一些张量,这样
class SampleNet(pl.LightningModule):
def __init__(self):
super().__init__()
self.save_hyperparameters()
self.layer = torch.nn.Linear(100, 1)
self.loss = torch.nn.CrossEntropy()
self.some_data = None # Initialize as None
def training_step(self, batch):
x, t = batch
out = self.layer(x)
loss = self.loss(out, t)
results = {'loss': loss}
return results
def training_epoch_end(self, outputs):
self.some_data = some_tensor_object
这是一个简化的示例,但我希望上面制作的检查点文件checkpoint_callback
记住该属性self.some_data
,但是当我从检查点加载模型时,它总是重置为None
. 我确认它在培训期间成功更新。
我尝试不将其初始化为 None ,init
但加载模型时该属性将消失。
我想避免将属性保存为不同的pt
文件,因为它与模型配置相关联,因此我稍后需要手动将文件与相应的检查点文件匹配。
是否可以在检查点文件中包含这样的张量属性?