QAT快速上手(fx mode)

目录

  • 1. 必要步骤速览-

  • 2. 各步骤详解-

    • 2.1 浮点模型准备-

    • 2.2 数据校准-

    • 2.3 量化训练-

    • 2.4 定点转换-

    • 2.5 模型编译-

  • 3. 常见问题-


大家会选择QAT方案,想必对于“量化”应该是已经有一些基本概念了(若无也没关系,感兴趣的话推荐阅读一下这篇帖子:神经网络量化背景)。通常来说,还是最推荐大家优先尝试 PTQ(post training quantization),毕竟相较于QAT而言上手容易代价小,对浮点模型代码无侵入,且地平线PTQ方案精度保持情况处于业界领先水平,对大多数常见视觉任务都可将量化损失控制在1%以内。但不管是再好的离线量化算法,都无法保证可以让所有的模型量化精度都达到业务需要。并且PTQ仅用到了少量的校准数据,即使在验证集上精度指标差不多,但由于量化不可避免会引入的信息丢失,仍然有可能导致模型的泛化性受到影响。而 QAT(quantization aware training)则是把量化误差作为训练的噪声,让模型在训练的过程中不断去学习适应这种噪声,最终我们得到的模型参数也会对量化更具鲁棒性。-
理论多说无益,那浮点模型到 QAT模型的“最后一公里”该如何快速抵达呢?Pytorch 1.8之后推出的torch.fx(官方文档)可以自动跟踪模型forward过程,能大大降低QAT的使用难度,对应的量化方案:FX Graph Mode Quantization,相较于 Eager Mode Quantization 其自动化程度会高很多,相应地操作复杂度也低一些,但也需要用户适当的调整模型以使得模型满足“symbolically traceable”。下面我们就直接来看一下具体步骤吧:

1. 必要步骤速览

整个QAT方案从浮点到部署模型就包括五个步骤:浮点模型准备、数据校准、量化训练(可选)、定点转换、模型编译。必要的步骤和示例代码如下所示,每个步骤的详细说明和注意事项可参考后文讲解。完整的示例推荐参考OE开发包中/ddk/samples/ai_toolchain/horizon_model_train_sample/plugin_basic目录下的fx_mode.py脚本。-

强烈建议在量化训练前(甚至是浮点模型设计阶段)先跳过训练过程完整走完prepare->convert->check步骤,确保模型可被硬件支持。

from horizon_plugin_pytorch.quantization import (
    convert_fx, 
    prepare_qat_fx,
    set_fake_quantize,
    FakeQuantState,
    check_model,
    compile_model,
)
from horizon_plugin_pytorch.quantization.qconfig import (
    default_calib_8bit_fake_quant_qconfig,
    default_qat_8bit_fake_quant_qconfig,
    default_calib_8bit_weight_32bit_out_fake_quant_qconfig,
    default_qat_8bit_weight_32bit_out_fake_quant_qconfig
)
from horizon_plugin_pytorch.march import March, set_march
set_march(March.BAYES) # 在prepare之前设置计算架构
# 1.准备浮点模型
float_model = load_float_model(pretrain=True) 
# 2.数据校准
calib_model = prepare_qat_fx(
    # 关于为何要使用deepcopy,可查看后文第三章常见问题 1
    copy.deepcopy(float_model),
    {
        "": default_calib_8bit_fake_quant_qconfig,
        "module_name": {
        # 关于为何要开启高精度输出可查看后文第三章常见问题 2
            "classifier": default_calib_8bit_weight_32bit_out_fake_quant_qconfig,
         },
     },
)
calib_model.eval()
# 关于FakeQuantState相关说明请参考后文第三章常见问题 3
set_fake_quantize(calib_model, FakeQuantState.CALIBRATION)
calibrate(calib_model)
# 评测数据校准精度
calib_model.eval()
set_fake_quantize(calib_model, FakeQuantState.VALIDATION)
evaluate(calib_model)
torch.save(calib_model.state_dict(), "calib-checkpoint.ckpt")
# 3.量化训练(若数据校准精度已达标,可跳过该步骤)
qat_model = prepare_qat_fx(
    copy.deepcopy(float_model),
    {
        "": default_qat_8bit_fake_quant_qconfig,
        "module_name": {
            "classifier": default_qat_8bit_weight_32bit_out_fake_quant_qconfig,
         },
     },
) 
qat_model.load_state_dict(calib_model.state_dict())
qat_model.train()
set_fake_quantize(qat_model, FakeQuantState.QAT)
train(qat_model)
# 评测量化训练精度
qat_model.eval()
set_fake_quantize(qat_model, FakeQuantState.VALIDATION)
evaluate(qat_model)
# 4.定点转换
base_model = qat_model # 校准精度满足预期时使用 calib_model
quantized_model = convert_fx(base_model)
# 评测定点精度(正常情况下该模型精度与hbm模型精度一致,因此请以此来评价最终部署精度)
evaluate(quantized_model)
# 5.模型编译(注意模型和数据要放在 CPU 上)
script_model = torch.jit.trace(quantized_model.cpu(), example_input)
check_model(script_model, [example_input])
compile_model(script_model,[example_input],hbm="model.hbm",input_source="pyramid",opt=3)

2. 各步骤详解

2.1 浮点模型准备

**a. 请使用足够的数据量将浮点模型正常训练至收敛后再进行量化训练。-

b. 强烈建议对输入数据进行归一化处理,有利于浮点收敛的同时也可使得模型对量化更友好。-

c. 建议您在浮点模型设计阶段对照算子支持列表,避免使用不支持的算子导致后续prepare qat或者编译报错。-

d. 若模型中使用了cpu算子,且您确认需要将其编译进模型中,可参考用户手册 4.2.4.4. 异构模型指南进行转换编译。-

e. 更多关于如何搭建量化友好模型的说明可参考用户手册 4.2.4.1浮点模型的要求**

虽然fx mode相较于eager mode对原始浮点模型代码侵入较小,但仍然需要对浮点模型做一些必要的改造以支持后续量化操作。

  • 在模型输入前插入 QuantStub节点,在模型输出后插入 DequantStub节点。有如下注意事项:

    • 多个输入仅在 scale 相同时可以共享 QuantStub,否则请为每个输入定义单独的 QuantStub
    • 建议使用horizon_plugin_pytorch.quantization.QuantStub 默认动态统计输入scale,若是可提前计算出scale的场景建议手动设置scale(例如bev模型的homo矩阵),对应的公版接口torch.quantization.QuantStub不支持手动设置。
  • 建议模型前后处理、loss等不需要量化的部分不要写在模型forward函数里,避免被误插入伪量化节点,进而影响模型精度。

  • 对于动态控制流以及一些python内置函数等symbolic trace不支持的操作(可查看官方说明),需要单独定义并使用 wrap 修饰,推荐写法如下:

    from horizon_plugin_pytorch.utils.fx_helper import wrap as fx_wrap

    @fx_wrap()
    def test(self, x):
    if self.training:
    pass

    def forward(self, x):
    ···
    x = self.test(x)
    return x

2.2 数据校准

对于部分模型,仅通过 Calibration 便可使精度达到要求,不必进行比较耗时的量化感知训练。即使模型经过量化校准后无法满足精度要求,此过程也可降低后续量化感知训练的难度,缩短训练时间,提升最终的训练精度。数据校准的具体配置方式及调参建议可参考QAT方案Calibration使用说明

2.3 量化训练

量化训练一些推荐的超参配置如下表所示:

超参

推荐配置

高级配置(如果推荐配置无效请尝试)

LR

从0.001开始,搭配StepLR做2次scale=0.1的lr decay

1. 调整lr在0.0001->0.001之间,配合1-2的lr decay。-
2. LR 更新策略也可以尝试把 StepLR 替换为 CosLR。-
3. QAT使用AMP,适当调小lr,过大导致nan。

Epoch

浮点epoch的10%

1. 根据loss和metric的收敛情况,考虑是否需要适当延长epoch。

Weight decay

与浮点一致

1. 建议在4e-5附近做适当调整。weight decay过小导致weight方差过大,过大导致输出较大的任务输出层weight方差过大。

optimizer

与浮点一致

1. 如果浮点训练采用的是 OneCycle 等会影响 LR 设置的优化器,建议不要与浮点保持一致,使用 SGD 替换。

transforms(数据增强)

与浮点一致

1. QAT阶段可以适当减弱,比如分类的颜色转换可以去掉,RandomResizeCrop的比例范围可以适当缩小

averaging_constant(qconfig_params)

1. 使用calibration后推荐减弱激活更新:-
weight averaging_constant=1.0-
activation averaging_constant=0.0

1. calibration的精度和浮点差距较大时:activation averaging_constant不要设置成0.0-
2. weight averaging_constant一般不需要设置成0.0,实际情况可以在(0,1.0]之间调整

强烈建议您先尝试数据校准,若精度不满足预期再进行量化训练(注意要加载数据校准后权重参数)。-
量化训练阶段的调参建议可以参考用户手册 量化训练精度调优建议

2.4 定点转换

请注意,定点模型和伪量化模型之间无法做到完全数值一致,所以请以定点模型的精度为准。若定点精度不达标,仍需要继续进行量化训练,建议多保留几个epoch的qat模型权重,便于寻找最优的定点精度。(qat或者calibrate精度高并不一定代表定点精度高,可以考虑进行一些回退,平衡最终的定点精度)

在正常情况下,定点模型的精度与板端部署精度是可以保持完全一致的,因此可使用该模型来评测最终部署精度。

2.5 模型编译

模型编译阶段包括以下三个步骤:

script_model = torch.jit.trace(quantized_model, example_input)
check_model(script_model.cpu(), [example_input])
compile_model(script_model,[example_input],hbm="model.hbm",input_source="pyramid",opt=O3)

compile_model()更多配置项请参考 用户手册-模型编译

其中trace之后生成的script_model可以使用horizon_plugin_pytorch.jit.save接口保存后迁移到其他机器上进行推理评测,由于推理保存后的模型要求device和trace时保持一致,使用to(device)操作修改可能会出现forward报错,具体原因和推荐的解决方式可参考用户手册-量化部署 PT 模型的跨设备 Inference 说明。-
若使用了rgb/bgr格式训练模型,部署时设置input_souce为pyramid或resizer,需要在trace之前手动插入预处理节点centered_yuv2rgb和centered_yuv2bgr ,具体可参考用户手册-RGB888 数据部署

3. 常见问题

1. 为何prepare之前要使用deepcopy?-
答:prepare_qat_fx以及convert_fx接口均不支持 inplace 参数,因此这两个接口的输入和输出模型会共享几乎所有属性,因此建议使用deepcopy复制一份,确保不改变原始输入模型。若无需保留输入模型,且未使用deepcopy,请不要对输入的模型做任何修改。-

2. 为何要设置高精度输出?-
答:依据神经网络量化背景中的介绍可知乘法累加器计算得到的激活值是int32的,为了让下一层op可以继续计算,会经过requantization的操作转为int8/int16,因此若最后一层是conv/linear节点的话,建议设置高精度输出,使得模型可以以int32格式直接输出,对精度保持情况大有裨益。

此外,在prepare qat之前通过`model.classifier.qconfig = default_qat_8bit_weight_32bit_out_fake_quant_qconfig`配置高精度输出也是可以的,且该方式优先级高于prepare时通过dict传入qconfig配置。 > plugin ≤ v1.6.2 配置高精度输出需使用 default_calib_out_8bit_fake_quant_qconfig ,但该参数将在后续版本中被废弃

3. 如何理解fake quantize的几种状态?-
fake quantize 一共有三种状态,分别需要在 QAT 、 calibration 、 validation 前使用set_fake_quantize将模型的 fake quantize 设置为对应的状态。在 calibration 状态下,仅观测各算子输入输出的统计量。在 QAT 状态下,除观测统计量外还会进行伪量化操作。而在 validation 状态下,不会观测统计量,仅进行伪量化操作。

class FakeQuantState(Enum):
    QAT = "qat"
    CALIBRATION = "calibration"
    VALIDATION = "validation"