1. 使用场景
在QAT过程中,默认情况下基本所有算子的计算精度都是int8,但在如下情况中我们可能需要使用int16计算:
- 部分算子如gridsample,其grid输入有着明确的物理含义,数值范围有可能超过int8的表示范围,必须使用int16精度计算;
- 在精度调优过程中通过debug工具观察到某些节点数据分布范围较大,int8表示范围不足。
2. 配置方式
a. J5计算平台仅支持activation int16计算;-
b. 后文介绍的配置方式可使得目标节点的输出精度更改为int16,相应的,其后一个节点的输入精度也会变为int16,因此请参考算子支持列表确保后一个节点支持int16输入,否则后续转换编译过程会报错;-
c. 若想修改模型输入的精度,则对quantstub节点配置int16即可。
2.1 数据校准
# default_calib_16bit_fake_quant_qconfig 的定义
# 请注意,plugin<1.6.2时 default_calib_16bit_fake_quant_qconfig 定义有误(使用了不支持的int16 weight),不可直接import后使用,需要按如下代码自行定义qconfig
from horizon_plugin_pytorch.quantization.qconfig import get_default_qconfig
default_calib_16bit_fake_quant_qconfig = get_default_qconfig(
activation_fake_quant="fake_quant",
weight_fake_quant="fake_quant",
activation_observer="percentile",
weight_observer="min_max",
activation_qkwargs={
"dtype": qint16,
},
weight_qkwargs={
"qscheme": torch.per_channel_symmetric,
"ch_axis": 0,
},
)
2.1.1 Fx mode
Fx mode也支持2.1.2中的配置方式,且优先级高于通过如下dict方式来配置(也就是说如果通过dict配置int8,同时又通过xxx.qconfig = int16_qconfig的方式进行配置,则对应节点会以int16精度计算)。
from horizon_plugin_pytorch.quantization.qconfig import default_calib_16bit_fake_quant_qconfig
···
calib_model = prepare_qat_fx(
copy.deepcopy(float_model),
{
"": default_calib_8bit_fake_quant_qconfig,
"module_name": {
# 实际使用时请替换成目标节点的名称,若名称输入错误会有日志提示找不到对应节点
"target_op_name": default_calib_16bit_fake_quant_qconfig,
},
},
)
2.1.2 Eager mode
Eager mode对应的prepare_qat(接口定义可查阅用户手册api reference)接口不支持通过dict传入qconfig配置,只能通过如下方式设置qconfig。
from horizon_plugin_pytorch.quantization.qconfig import default_calib_16bit_fake_quant_qconfig
···
float_model.qconfig = default_calib_8bit_fake_quant_qconfig
# 实际使用时请替换成具体目标节点,确保通过索引方式可查找到对应节点
float_model.target_op.qconfig = default_calib_16bit_fake_quant_qconfig
qat_model = prepare_qat(float_model)
2.2 量化训练
# default_qat_16bit_fake_quant_qconfig 的定义
# 请注意,plugin<1.6.2时 default_qat_16bit_fake_quant_qconfig 定义有误(使用了不支持的int16 weight),不可直接import后使用,需要按如下代码自行定义qconfig
from horizon_plugin_pytorch.quantization.qconfig import get_default_qconfig
default_qat_16bit_fake_quant_qconfig = get_default_qconfig(
activation_fake_quant="fake_quant",
weight_fake_quant="fake_quant",
activation_observer="min_max",
weight_observer="min_max",
activation_qkwargs={
"dtype": qint16,
},
weight_qkwargs={
"qscheme": torch.per_channel_symmetric,
"ch_axis": 0,
},
)
2.2.1 Fx mode
Fx mode同样也支持后文2.2.2节的配置方式,且优先级高于通过如下dict方式来配置(也就是说如果通过dict配置int8,同时又通过xxx.qconfig = int16_qconfig的方式进行配置,则对应节点会以int16精度计算)。
from horizon_plugin_pytorch.quantization.qconfig import default_qat_16bit_fake_quant_qconfig
···
qat_model = prepare_qat_fx(
copy.deepcopy(float_model),
{
"": default_qat_8bit_fake_quant_qconfig,
"module_name": {
# 实际使用时请替换成目标节点的名称,若名称输入错误会有日志提示找不到对应节点
"target_op": default_qat_16bit_fake_quant_qconfig,
},
},
)
2.2.2 Eager mode
Eager mode对应的prepare_qat(接口定义可查阅用户手册api reference)接口不支持通过dict传入qconfig配置,只能通过如下方式设置qconfig。
from horizon_plugin_pytorch.quantization.qconfig import default_qat_16bit_fake_quant_qconfig
···
float_model.qconfig = default_qat_8bit_fake_quant_qconfig
# 实际使用时请替换成具体目标节点,确保通过索引方式可查找到对应节点
float_model.target_op.qconfig = default_qat_16bit_fake_quant_qconfig
qat_model = prepare_qat(float_model)