0

jsonnet我希望通过使用配置文件在标准 AllenNLP 模型训练中禁用序列化所有模型/状态权重。

原因是我正在使用 Optuna 运行自动超参数优化。测试数十个模型很快就会填满驱动器。我已经通过设置禁用了检查num_serialized_models_to_keep0

trainer +: {
    checkpointer +: {
        num_serialized_models_to_keep: 0,
    },

我不希望设置serialization_dir为,None因为我仍然想要关于记录中间指标等的默认行为。我只想禁用默认模型状态、训练状态和最佳模型权重写入

除了我上面设置的选项之外,是否有任何默认的训练器或检查点选项来禁用模型权重的所有序列化?我检查了 API 文档和网页,但找不到任何内容。

如果我需要自己定义这样一个选项的功能,我应该在我的模型子类中覆盖 AllenNLP 中的哪些基本函数?

或者,当训练结束时,它们对清理中间模型状态有什么用吗?

编辑: @petew 的答案显示了自定义检查点的解决方案,但我不清楚如何使此代码可allennlp train用于我的用例。

我希望从配置文件中调用 custom_checkpointer,如下所示:

trainer +: {
    checkpointer +: {
        type: empty,
    },

调用时加载检查点的最佳做法是什么allennlp train --include-package <$my_package>

我的 my_package 包含子目录中的子模块,例如my_package/modelss 和my_package/training. 我想将自定义检查点代码放在my_package/training/custom_checkpointer.py 我的主模型位于my_package/models/main_model.py. 我是否必须在 main_model 类中编辑或导入任何代码/函数才能使用自定义检查点?

4

1 回答 1

1

您可以创建并注册一个基本上什么都不做的自定义检查点:

@Checkpointer.register("empty")
class EmptyCheckpointer(Registrable):
    def maybe_save_checkpoint(
        self, trainer: "allennlp.training.trainer.Trainer", epoch: int, batches_this_epoch: int
    ) -> None:
        pass

    def save_checkpoint(
        self,
        epoch: Union[int, str],
        trainer: "allennlp.training.trainer.Trainer",
        is_best_so_far: bool = False,
        save_model_only=False,
    ) -> None:
        pass

    def find_latest_checkpoint(self) -> Optional[Tuple[str, str]]:
        pass

    def restore_checkpoint(self) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        return {}, {}

    def best_model_state(self) -> Dict[str, Any]:
        return {}
于 2020-10-12T17:59:56.383 回答