
I am trying to implement the following function to save the model_state checkpoint:

def train_epoch(self):
    for epoch in tqdm.trange(self.epoch, self.max_epoch, desc='Train Epoch', ncols=100):
        self.epoch = epoch      # increments the epoch of Trainer
        checkpoint = {} # fixme: here checkpoint!!!
        # model_save_criteria = self.model_save_criteria
        self.train()
        if epoch % 1 == 0:      # always true, so validation runs every epoch
            self.validate(checkpoint)
        checkpoint_latest = {
            'epoch': self.epoch,
            'arch': self.model.__class__.__name__,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optim.state_dict(),
            'model_save_criteria': self.model_save_criteria
        }
        checkpoint['checkpoint_latest'] = checkpoint_latest
        torch.save(checkpoint, self.model_pth)

Previously, I did the same thing by just running a for loop:

train_states = {}
for epoch in range(max_epochs):
    running_loss = 0
    time_batch_start = time.time()
    model.train()
    for bIdx, sample in enumerate(train_loader):
        ...
        train...
        validation...
        train_states_latest = {
          'epoch': epoch + 1,
          'model_state_dict': model.state_dict(),
          'optimizer_state_dict': optimizer.state_dict(),
          'model_save_criteria': chosen_criteria}
        train_states['train_states_latest'] = train_states_latest
        torch.save(train_states, FILEPATH_MODEL_SAVE)

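Is there a way to initialize checkpoint = {} once and update it in every loop? Or is creating checkpoint = {} in every epoch fine, since the model itself holds the state_dict()? The only issue is that I am overwriting the checkpoint every time.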


1 Answer


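You can avoid overwriting the checkpoint by simply changing the FILEPATH_MODEL_SAVE path so that it includes the epoch or iteration number. For example (using your original code),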

train_states = {}
for epoch in range(max_epochs):
    running_loss = 0
    time_batch_start = time.time()
    model.train()
    for bIdx, sample in enumerate(train_loader):
        ...
        train...
        validation...
        train_states_latest = {
          'epoch': epoch + 1,
          'model_state_dict': model.state_dict(),
          'optimizer_state_dict': optimizer.state_dict(),
          'model_save_criteria': chosen_criteria}
        train_states['train_states_latest'] = train_states_latest
        

        # This is the code you can add
        FILEPATH_MODEL_SAVE = "Epoch{}batch{}model_weights.pth".format(epoch, bIdx)
        torch.save(train_states, FILEPATH_MODEL_SAVE)


With this new line added just above torch.save, each save goes to its own file, so you avoid overwriting the checkpoint.
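Saving on every batch can produce a lot of files, so it may help to build the file names in one place and delete older checkpoints as you go. The following is only a minimal sketch of that idea: `checkpoint_path`, `prune_checkpoints`, the `save_dir` argument, and the zero-padding widths are hypothetical names and choices, not part of the original code.

```python
from pathlib import Path


def checkpoint_path(save_dir, epoch, batch_idx=None):
    """Build a unique checkpoint path for an epoch (and optionally a batch).

    Zero-padding keeps the files in training order when sorted by name
    (assumes fewer than 10,000 epochs and 1,000,000 batches per epoch).
    """
    name = f"epoch{epoch:04d}"
    if batch_idx is not None:
        name += f"_batch{batch_idx:06d}"
    return Path(save_dir) / f"{name}_model_weights.pth"


def prune_checkpoints(save_dir, keep_last=3):
    """Delete all but the `keep_last` newest checkpoints, ordered by file name.

    Returns the names of the files that were removed.
    """
    files = sorted(Path(save_dir).glob("epoch*_model_weights.pth"))
    old = files[:-keep_last] if keep_last > 0 else files
    for f in old:
        f.unlink()
    return [f.name for f in old]
```

Inside the training loop, the save line would then become something like torch.save(train_states, checkpoint_path(save_dir, epoch, bIdx)), optionally followed by prune_checkpoints(save_dir) once per epoch. torch.save accepts a path-like object, so the Path can be passed directly.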

萨尔塔克

answered 2021-07-24T20:36:03.417