Freezing BN statistics during quantization-aware training (QAT) is a common training technique described in Google's quantization whitepaper, and a code snippet in the official PyTorch static quantization tutorial shows how to do it in PyTorch:
num_train_batches = 20

# QAT takes time and one needs to train over a few epochs.
# Train and check accuracy after each epoch
for nepoch in range(8):
    train_one_epoch(qat_model, criterion, optimizer, data_loader, torch.device('cpu'), num_train_batches)
    if nepoch > 3:
        # Freeze quantizer parameters
        qat_model.apply(torch.quantization.disable_observer)
    if nepoch > 2:
        # Freeze batch norm mean and variance estimates
        qat_model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)

    # Check the accuracy after each epoch
    quantized_model = torch.quantization.convert(qat_model.eval(), inplace=False)
    quantized_model.eval()
    top1, top5 = evaluate(quantized_model, criterion, data_loader_test, neval_batches=num_eval_batches)
    print('Epoch %d: Evaluation accuracy on %d images, %2.2f' % (nepoch, num_eval_batches * eval_batch_size, top1.avg))
However, as its title suggests, that is an eager-mode snippet, whereas I am trying to do quantization-aware training with the prototype FX Graph Mode. The official PyTorch tutorial only shows how to do PTQ with FX Graph Mode and gives just a brief sketch of QAT:
import copy
import torch
import torch.quantization.quantize_fx as quantize_fx

#
# quantization aware training for static quantization
#
model_to_quantize = copy.deepcopy(model_fp)
qconfig_dict = {"": torch.quantization.get_default_qat_qconfig('qnnpack')}
model_to_quantize.train()
# prepare
model_prepared = quantize_fx.prepare_qat_fx(model_to_quantize, qconfig_dict)
# training loop (not shown)
# quantize
model_quantized = quantize_fx.convert_fx(model_prepared)

#
# fusion
#
model_to_quantize = copy.deepcopy(model_fp)
model_fused = quantize_fx.fuse_fx(model_to_quantize)
As the snippet above shows, the tutorial simply omits the training loop. What I want to know is whether the eager-mode API for freezing BN statistics, torch.nn.intrinsic.qat.freeze_bn_stats, is still available in FX Graph Mode, meaning that I could call model_prepared_fx.apply(torch.nn.intrinsic.qat.freeze_bn_stats) to achieve my goal. Or should I use another mechanism instead?
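For concreteness, here is a minimal sketch of the training loop I have in mind, assuming the eager-mode freezing APIs carry over to the GraphModule returned by prepare_qat_fx (the helpers train_one_epoch, criterion, optimizer, data_loader, and num_train_batches are reused from the eager-mode snippet above):

for nepoch in range(8):
    train_one_epoch(model_prepared, criterion, optimizer, data_loader,
                    torch.device('cpu'), num_train_batches)
    if nepoch > 3:
        # Freeze quantizer parameters; nn.Module.apply() recurses over all
        # submodules, so this should reach every FakeQuantize module
        model_prepared.apply(torch.quantization.disable_observer)
    if nepoch > 2:
        # Freeze batch norm mean and variance estimates; this is the call
        # I am unsure about under FX Graph Mode (assumption)
        model_prepared.apply(torch.nn.intrinsic.qat.freeze_bn_stats)

# convert after training, as in the FX snippet above
model_quantized = quantize_fx.convert_fx(model_prepared.eval())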