0

我有一个光线调谐分析对象,我能够从中获得最佳检查点:

analysis = tune_robert_asha(num_samples=2)
best_ckpt = analysis.best_checkpoint

但我无法用它恢复我的 pytorch 闪电模型。

我尝试:

MyLightningModel.load_from_checkpoint(
    os.path.join(analysis.best_checkpoint, "checkpoint")
)

但是 MyLightningModel 在其构造函数中公开了一个配置,以便光线调谐可以为每个试验设置某些超参数:

class MyLightningModel (pl.LightningModule):
    def __init__(self, config=None):
        self.lr = config["lr"]
        self.batch_size = config["batch_size"]
        self.layer_size = config["layer_size"]

        super(MyLightningModel , self).__init__()
        self.lstm = nn.LSTM(768, self.layer_size, num_layers=1, bidirectional=False)
        self.out = nn.Linear(self.layer_size, 1)

因此,当我尝试运行 load_from_checkpoint 时,由于配置未定义,MyLightningModel 的构造函数中出现错误:


TypeError Traceback (last last call last) in () 1 MyLightningModel.load_from_checkpoint( ----> 2 os.path.join(analysis.best_checkpoint, "checkpoint") 3 )

2 帧init (self, config) 3 def init (self, config=None): 4 ----> 5 self.lr = config["lr"] 6 self.batch_size = config["batch_size"] 7 self .layer_size = 配置["layer_size"]

TypeError:“NoneType”对象不可下标

这个应该怎么处理?

4

0 回答 0