I have a DL model that is trained in two stages:

  1. pre-training on synthetic data
  2. fine-tuning on real-world data

The model is saved after stage 1. In stage 2 the model is created, its weights are loaded from the .pth file, and training starts again on the new data. I want to apply QAT, but I am running into a problem in stage 2: the loss is really large (about where synthetic training without QAT starts, while it should be more than 60x smaller). I suspect the observers being reset and then frozen are to blame. The question is: what is the correct way to load a QAT model and continue training it?

Code for stage 1:

import torch

...
self.create_net()
self.net.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
torch.quantization.prepare_qat(self.net, inplace=True)
# Skip fuse Conv-Bn-ReLU
...

# In training loop
if train_iter == 40_000:
    print("Freeze batch norm mean and variance estimates")
    self.net.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
if train_iter == 50_000:
    print("Freeze quantizer parameters")
    self.net.apply(torch.quantization.disable_observer)
...

# After training
# Do not convert to quantized model since it'll be trained again
torch.save(self.net.state_dict(), str(filepath))

Code for stage 2:

import torch

...
self.create_net()
custom_load_state_dict(self.net, torch.load(str(filepath), map_location="cpu"))
self.net.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
torch.quantization.prepare_qat(self.net, inplace=True)
# Freeze observers and bn immediately after model load
self.net.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
self.net.apply(torch.quantization.disable_observer)
...

# Another file
def custom_load_state_dict(target_module: torch.nn.Module, source_state_dict: dict) -> None:
    """Copy only the entries whose name and shape match the target module."""
    target_state_dict = target_module.state_dict()
    unmatched_keys = []
    for key, source_tensor in source_state_dict.items():
        if key in target_state_dict and target_state_dict[key].shape == source_tensor.shape:
            target_state_dict[key] = source_tensor
        else:
            unmatched_keys.append(key)
    target_module.load_state_dict(target_state_dict)
    if unmatched_keys:
        print(f'Unmatched keys during model loading:\n{unmatched_keys}')
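
To make the observer suspicion more concrete, I also checked which QAT-related entries the stage-1 checkpoint actually contains. In eager-mode QAT the learned quantization state lives in the state_dict of the fake-quant/observer submodules (keys containing weight_fake_quant or activation_post_process, holding scale, zero_point and min/max buffers); the key filter below is just my own assumption, not an official API:

# Quick diagnostic (my own check, not part of the training code):
# list the fake-quant / observer entries stored in the stage-1 checkpoint
ckpt = torch.load(str(filepath), map_location="cpu")
qat_keys = [k for k in ckpt
            if "weight_fake_quant" in k or "activation_post_process" in k]
print(f"{len(qat_keys)} QAT-related entries in the checkpoint, for example:")
print(qat_keys[:5])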

I have tried initializing QAT first and then loading the weights, but it did not change anything. I have also tried converting the model to QAT manually:

# Instead:
torch.quantization.prepare_qat(self.net, inplace=True)

# Do:
from torch.ao.quantization import get_default_qat_module_mappings, propagate_qconfig_, convert

mapping = get_default_qat_module_mappings()
propagate_qconfig_(self.net, qconfig_dict=None)
convert(self.net, mapping=mapping, inplace=True, remove_qconfig=False)

But after training, when I try to convert to a quantized model, it raises an error:

# Throws error - missing observers
quantized_model = torch.quantization.convert(quantized_model, inplace=True)
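
My current guess about this error: torch.quantization.prepare_qat does not only swap modules, it also runs a prepare() pass that attaches the activation observers which the final convert() later reads the scale/zero_point from. The manual propagate_qconfig_ + convert route above skips that pass, so there is nothing for convert() to use. A sketch of the extra step, based on my reading of the eager-mode source, so treat it as an assumption rather than documented behaviour:

from torch.ao.quantization import prepare, get_default_qat_module_mappings

# Attach activation observers / fake-quant modules the same way prepare_qat
# does internally; without this, the final convert() finds no observers.
mapping = get_default_qat_module_mappings()
prepare(self.net,
        observer_non_leaf_module_list=set(mapping.values()),
        inplace=True)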

When I skip freezing BN and the observers right after loading the model, it seems to work fine. But is that correct? Doesn't it destroy the previously learned quantization levels?
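
For reference, this is the kind of check I use to see whether the learned quantization parameters survive the load and whether the observers are really disabled; the helper is my own sketch and assumes eager-mode QAT, where the fake-quant state is held in FakeQuantize modules:

import torch
from torch.ao.quantization import FakeQuantize

def dump_fake_quant_state(net: torch.nn.Module, tag: str) -> None:
    # Print scale/zero_point and the enable flags of every FakeQuantize module
    for name, module in net.named_modules():
        if isinstance(module, FakeQuantize):
            print(tag, name,
                  "scale=", module.scale.flatten()[:2].tolist(),
                  "zero_point=", module.zero_point.flatten()[:2].tolist(),
                  "observer_enabled=", bool(module.observer_enabled.item()),
                  "fake_quant_enabled=", bool(module.fake_quant_enabled.item()))

# e.g. call it right before and right after custom_load_state_dict / prepare_qat
# to see which scales actually change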
