# SqueezeNext.__init__ requires three positional arguments: width_x, blocks and
# num_classes (calling SqueezeNext() with none raises the TypeError shown in
# the traceback). NOTE(review): the values below assume the SqNxt-23 1.0x
# layout (per-stage block counts 6, 6, 8, 1) on CIFAR-10, inferred from the
# checkpoint name 'SqNxt_23_1x_Cifar.ckpt' -- confirm against the SqueezeNext
# class definition and the dataset's class count (e.g. 100 for CIFAR-100).
model = SqueezeNext(width_x=1.0, blocks=[6, 6, 8, 1], num_classes=10)
model = model.to(device)
def load_checkpoint(model, optimizer, losslogger, filename='SqNxt_23_1x_Cifar.ckpt',
                    map_location=None):
    """Restore model/optimizer state and training metadata from a checkpoint.

    The model and optimizer must already be constructed (with the same
    architecture / parameter groups that produced the checkpoint); this
    routine only loads their states in place via ``load_state_dict``.

    Args:
        model: module that receives ``checkpoint['state_dict']``.
        optimizer: optimizer that receives ``checkpoint['optimizer']``.
        losslogger: returned unchanged when no checkpoint file exists;
            otherwise replaced by ``checkpoint['losslogger']``.
        filename: path to the checkpoint file.
        map_location: forwarded to ``torch.load`` (e.g. ``'cpu'`` or a
            ``torch.device``) so a checkpoint saved on one device can be
            loaded on another. ``None`` keeps ``torch.load``'s default.

    Returns:
        Tuple ``(model, optimizer, start_epoch, losslogger)``; ``start_epoch``
        is ``checkpoint['epoch']``, or 0 when no checkpoint file is found.
    """
    start_epoch = 0
    if not os.path.isfile(filename):
        print("=> no checkpoint found at '{}'".format(filename))
        return model, optimizer, start_epoch, losslogger

    print("=> loading checkpoint '{}'".format(filename))
    # NOTE(review): torch.load unpickles the file -- only load trusted
    # checkpoints. Recent PyTorch releases default to weights_only=True, which
    # may reject a custom losslogger object; TODO confirm for the installed
    # version.
    checkpoint = torch.load(filename, map_location=map_location)
    start_epoch = checkpoint['epoch']
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    losslogger = checkpoint['losslogger']
    print("=> loaded checkpoint '{}' (epoch {})".format(filename, start_epoch))
    return model, optimizer, start_epoch, losslogger
# Resume training state from the checkpoint file if it exists.
# NOTE(review): `optimizer` and `losslogger` must already be defined at this
# point, with `optimizer` built from this model's parameters -- neither is
# created anywhere in this snippet.
model, optimizer, start_epoch, losslogger = load_checkpoint(
    model=model,
    optimizer=optimizer,
    losslogger=losslogger,
)
TypeError                                 Traceback (most recent call last)
in <module>()
     41 test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=80, num_workers=8, shuffle=False)
     42 
---> 43 model = SqueezeNext()
     44 model = model.to(device)
     45 def load_checkpoint(model, optimizer, losslogger, filename='SqNxt_23_1x_Cifar.ckpt'):

TypeError: __init__() missing 3 required positional arguments: 'width_x', 'blocks', and 'num_classes'
I don't think I have implemented this correctly!!