我将模型、优化器、学习率调度器（scheduler）和梯度缩放器（GradScaler）保存在同一个通用检查点（general checkpoint）中。
现在加载它们时，它们都能正常加载，但在第一次迭代后，scaler.step(optimizer)
会引发以下错误：
Traceback (most recent call last):
File "HistNet/trainloop.py", line 92, in <module>
scaler.step(optimizer)
File "/opt/conda/lib/python3.8/site-packages/torch/cuda/amp/grad_scaler.py", line 333, in step
retval = optimizer.step(*args, **kwargs)
File "/opt/conda/lib/python3.8/site-packages/torch/optim/lr_scheduler.py", line 65, in wrapper
return wrapped(*args, **kwargs)
File "/opt/conda/lib/python3.8/site-packages/torch/optim/optimizer.py", line 89, in wrapper
return func(*args, **kwargs)
File "/opt/conda/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
return func(*args, **kwargs)
File "/opt/conda/lib/python3.8/site-packages/torch/optim/adam.py", line 108, in step
F.adam(params_with_grad,
File "/opt/conda/lib/python3.8/site-packages/torch/optim/functional.py", line 86, in adam
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
RuntimeError: The size of tensor a (32) must match the size of tensor b (64) at non-singleton dimension 0
现在我真的不明白为什么张量的形状会不匹配。我的做法与官方文档基本一致，以下是我代码的精简版本：
# Training script (shortened excerpt). Dataset, model1, parameters, lr, betas,
# decay_rate, resume, path, config, device, cost_function, TF, trange and tqdm
# come from code not shown here.
# NOTE(review): load_checkpoint must already be defined when this code runs;
# in this excerpt its `def` appears after the call below.
dataloader = DataLoader(Dataset)
# NOTE: this rebinds `model1` from the class to the instance.
model1 = model1()
# Adam's saved state is matched to the current parameters purely by their
# position in the optimizer's parameter list. `parameters` must therefore be
# built in the same, deterministic order on every run (e.g. model1.parameters()),
# never from a set or an unordered merge of several modules. Otherwise the
# restored exp_avg/exp_avg_sq buffers land on tensors of a different shape and
# step() fails with a "size of tensor a ... must match" error.
optimizer = optim.Adam(parameters, lr, betas)
scheduler = optim.lr_scheduler.LambdaLR(optimizer, lambda epoch: decay_rate**epoch)
scaler = amp.GradScaler()

# Start epoch for a fresh run. Previously `epoch_resume` was undefined when
# `resume` was false.
# NOTE(review): 1 assumes epochs are 1-based (the range end includes
# config['epochs']). Confirm.
epoch_resume = 1
if resume:
    epoch_resume = load_checkpoint(path, model1, optimizer, scheduler, scaler)

loss_list = []
for epoch in trange(epoch_resume, config['epochs'] + 1, desc='Epochs'):
    for content_image, style_image in tqdm(dataloader, desc='Dataloader'):
        content_image, style_image = content_image.to(device), style_image.to(device)
        with amp.autocast():
            content_image = TF.rgb_to_grayscale(content_image)
            s = TF.rgb_to_grayscale(style_image)
            deformation_field = model1(s, content_image)
            output_image = F.grid_sample(content_image, deformation_field.float(), align_corners=False)
            loss_after = cost_function(output_image, s, device=device)
        # Keep a detached copy. Storing the raw loss would keep every
        # iteration's autograd graph alive and steadily leak memory.
        loss_list.append(loss_after.detach())
        # Backward runs outside autocast, as the AMP docs recommend.
        scaler.scale(loss_after).backward()
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()
    # NOTE(review): stepped once per epoch because the LR lambda is written in
    # terms of epochs. The original indentation was lost, so confirm placement.
    scheduler.step()
    torch.save({
        'epoch': epoch,
        'model1_state_dict': model1.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'scaler_state_dict': scaler.state_dict(),
    }, path)
def _check_optimizer_state_shapes(optimizer):
    """Raise if a restored per-parameter optimizer buffer does not fit its parameter.

    ``Optimizer.load_state_dict`` pairs the saved state with the current
    parameters purely by their position in ``param_groups``. If the optimizer
    was built with its parameters in a different order than when the checkpoint
    was saved, buffers end up attached to the wrong tensors. The problem then
    only surfaces later, deep inside ``optimizer.step()``. Checking right after
    loading turns that into a clear, early error.

    Raises:
        RuntimeError: if any checked buffer's shape differs from its parameter's.
    """
    # State entries that PyTorch's built-in Adam/AdamW/SGD/RMSprop keep with
    # exactly the parameter's shape. Other entries (e.g. ``step``) are skipped.
    elementwise_keys = ("exp_avg", "exp_avg_sq", "max_exp_avg_sq", "momentum_buffer", "square_avg")
    for group_idx, group in enumerate(optimizer.param_groups):
        for param_idx, param in enumerate(group["params"]):
            # .get avoids inserting empty entries for parameters without state.
            param_state = optimizer.state.get(param, {})
            for key in elementwise_keys:
                buf = param_state.get(key)
                if torch.is_tensor(buf) and buf.shape != param.shape:
                    raise RuntimeError(
                        f"Optimizer state {key!r} for param_groups[{group_idx}]"
                        f"['params'][{param_idx}] has shape {tuple(buf.shape)}, "
                        f"but the parameter has shape {tuple(param.shape)}. The "
                        "optimizer was probably constructed with its parameters in "
                        "a different order than when the checkpoint was saved."
                    )


def load_checkpoint(checkpoint_path, model1, optimizer, scheduler, scaler, map_location=None):
    """Restore training state from a checkpoint and return the epoch to resume at.

    Reads the dict saved with ``torch.save``. It must have the keys
    ``'model1_state_dict'``, ``'optimizer_state_dict'``,
    ``'scheduler_state_dict'``, ``'scaler_state_dict'`` and ``'epoch'``.
    The function loads each state dict into the given object and puts the model
    back in training mode. It then checks that the restored optimizer buffers
    match the current parameter shapes.

    Args:
        checkpoint_path: path of the checkpoint file.
        model1, optimizer, scheduler, scaler: objects restored in place. They
            must be constructed the same way (and with parameters in the same
            order) as when the checkpoint was saved.
        map_location: forwarded to ``torch.load``, e.g. ``'cpu'`` to load a GPU
            checkpoint on a machine without CUDA. ``None`` keeps the
            ``torch.load`` default.

    Returns:
        The saved epoch plus one, i.e. the epoch at which training resumes.

    Raises:
        RuntimeError: if an optimizer buffer does not match its parameter's
            shape. The usual cause is parameters passed to the optimizer in a
            different order than at save time.
    """
    checkpoint = torch.load(checkpoint_path, map_location=map_location)
    model1.load_state_dict(checkpoint['model1_state_dict'])
    model1.train()
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    # Fail here with a clear message instead of inside the first optimizer.step().
    _check_optimizer_state_shapes(optimizer)
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    scaler.load_state_dict(checkpoint['scaler_state_dict'])
    return checkpoint['epoch'] + 1