我有一个光线调谐分析对象,我能够从中获得最佳检查点:
analysis = tune_robert_asha(num_samples=2)
best_ckpt = analysis.best_checkpoint
但我无法用它恢复我的 pytorch 闪电模型。
我尝试:
MyLightningModel.load_from_checkpoint(
os.path.join(analysis.best_checkpoint, "checkpoint")
)
但是 MyLightningModel 在其构造函数中公开了一个配置,以便光线调谐可以为每个试验设置某些超参数:
class MyLightningModel (pl.LightningModule):
def __init__(self, config=None):
self.lr = config["lr"]
self.batch_size = config["batch_size"]
self.layer_size = config["layer_size"]
super(MyLightningModel , self).__init__()
self.lstm = nn.LSTM(768, self.layer_size, num_layers=1, bidirectional=False)
self.out = nn.Linear(self.layer_size, 1)
因此,当我尝试运行 load_from_checkpoint 时,由于配置未定义,MyLightningModel 的构造函数中出现错误:
TypeError Traceback (last last call last) in () 1 MyLightningModel.load_from_checkpoint( ----> 2 os.path.join(analysis.best_checkpoint, "checkpoint") 3 )
2 帧init (self, config) 3 def init (self, config=None): 4 ----> 5 self.lr = config["lr"] 6 self.batch_size = config["batch_size"] 7 self .layer_size = 配置["layer_size"]
TypeError:“NoneType”对象不可下标
这个应该怎么处理?