J5算法工具链int16配置方式(QAT)

1. 使用场景

在QAT过程中,默认情况下基本所有算子的计算精度都是int8,但在如下情况中我们可能需要使用int16计算:

  1. 部分算子如gridsample,其grid输入有着明确的物理含义,数值范围有可能超过int8的表示范围,必须使用int16精度计算;
  2. 在精度调优过程中通过debug工具观察到某些节点数据分布范围较大,int8表示范围不足。

2. 配置方式

a. J5计算平台仅支持activation int16计算;-

b. 后文介绍的配置方式可使得目标节点的输出精度更改为int16,相应的,其后一个节点的输入精度也会变为int16,因此请参考算子支持列表确保后一个节点支持int16输入,否则后续转换编译过程会报错;-

c. 若想修改模型输入的精度,则对quantstub节点配置int16即可。

2.1 数据校准

# default_calib_16bit_fake_quant_qconfig 的定义
# 请注意,plugin<1.6.2时 default_calib_16bit_fake_quant_qconfig 定义有误(使用了不支持的int16 weight),不可直接import后使用,需要按如下代码自行定义qconfig
from horizon_plugin_pytorch.quantization.qconfig import get_default_qconfig

default_calib_16bit_fake_quant_qconfig = get_default_qconfig(
    activation_fake_quant="fake_quant",
    weight_fake_quant="fake_quant",
    activation_observer="percentile",
    weight_observer="min_max",
    activation_qkwargs={
        "dtype": qint16,
    },
    weight_qkwargs={
        "qscheme": torch.per_channel_symmetric,
        "ch_axis": 0,
    },
)

2.1.1 Fx mode

Fx mode也支持2.1.2中的配置方式,且优先级高于通过如下dict方式来配置(也就是说如果通过dict配置int8,同时又通过xxx.qconfig = int16_qconfig的方式进行配置,则对应节点会以int16精度计算)。

from horizon_plugin_pytorch.quantization.qconfig import default_calib_16bit_fake_quant_qconfig

···
calib_model = prepare_qat_fx(
    copy.deepcopy(float_model),
    {
        "": default_calib_8bit_fake_quant_qconfig,
        "module_name": {
            # 实际使用时请替换成目标节点的名称,若名称输入错误会有日志提示找不到对应节点
            "target_op_name": default_calib_16bit_fake_quant_qconfig,
         },
     },
)

2.1.2 Eager mode

Eager mode对应的prepare_qat(接口定义可查阅用户手册api reference)接口不支持通过dict传入qconfig配置,只能通过如下方式设置qconfig。

from horizon_plugin_pytorch.quantization.qconfig import default_calib_16bit_fake_quant_qconfig

···
float_model.qconfig = default_calib_8bit_fake_quant_qconfig
# 实际使用时请替换成具体目标节点,确保通过索引方式可查找到对应节点
float_model.target_op.qconfig = default_calib_16bit_fake_quant_qconfig
qat_model = prepare_qat(float_model)

2.2 量化训练

# default_qat_16bit_fake_quant_qconfig 的定义
# 请注意,plugin<1.6.2时 default_qat_16bit_fake_quant_qconfig 定义有误(使用了不支持的int16 weight),不可直接import后使用,需要按如下代码自行定义qconfig
from horizon_plugin_pytorch.quantization.qconfig import get_default_qconfig

default_qat_16bit_fake_quant_qconfig = get_default_qconfig(
    activation_fake_quant="fake_quant",
    weight_fake_quant="fake_quant",
    activation_observer="min_max",
    weight_observer="min_max",
    activation_qkwargs={
        "dtype": qint16,
    },
    weight_qkwargs={
        "qscheme": torch.per_channel_symmetric,
        "ch_axis": 0,
    },
)

2.2.1 Fx mode

Fx mode同样也支持后文2.2.2节的配置方式,且优先级高于通过如下dict方式来配置(也就是说如果通过dict配置int8,同时又通过xxx.qconfig = int16_qconfig的方式进行配置,则对应节点会以int16精度计算)。

from horizon_plugin_pytorch.quantization.qconfig import default_qat_16bit_fake_quant_qconfig

···
qat_model = prepare_qat_fx(
    copy.deepcopy(float_model),
    {
        "": default_qat_8bit_fake_quant_qconfig,
        "module_name": {
            # 实际使用时请替换成目标节点的名称,若名称输入错误会有日志提示找不到对应节点
            "target_op": default_qat_16bit_fake_quant_qconfig,
         },
     },
)

2.2.2 Eager mode

Eager mode对应的prepare_qat(接口定义可查阅用户手册api reference)接口不支持通过dict传入qconfig配置,只能通过如下方式设置qconfig。

from horizon_plugin_pytorch.quantization.qconfig import default_qat_16bit_fake_quant_qconfig

···
float_model.qconfig = default_qat_8bit_fake_quant_qconfig
# 实际使用时请替换成具体目标节点,确保通过索引方式可查找到对应节点
float_model.target_op.qconfig = default_qat_16bit_fake_quant_qconfig
qat_model = prepare_qat(float_model)