torch.nn.Parameter如何量化

1.芯片型号:J5

2.天工开物开发包OpenExplorer版本:OE 1.1.29

3.问题定位:QAT训练

4.问题具体描述:你好,我当前定义了一个新的层,其中包含了一个nn.Parameter类型的变量

然后这个类的前推如下,请问edge_weights这个nn.Parameter类型的变量要如何量化?

您好~建议可以改写为:quant_edge_weights = self.quant(self.edge_weights)

这个self.edge_weights是一个可学习的权重。那么self.quant需要如何定义scale?是self.quant=QuantStub(scale=None), 还是self.quant=QuantStub(scale=1.0/128.0)

self.quant=QuantStub() 即可,根据输入动态统计

你好,我现在的forward调用了torch.sum,在QAT训练时是否要替换?

您好~torch.sum是可以被支持的(4.2.7.2. 支持的公版算子 — Horizon Open Explorer),eager mode需要替换(包括*、/、+也需要进行替换),fx mode不需要

你好,现在我在编译hbm时报错了。forward如下

出错在197行

报错信息:

File “/open_explorer/pyproject/bevod_v19/horizon_bev/hat/models/necks/bifpn.py”, line 197, in forward

quant_weights_sum=self.sum.sum(quant_relu_edge_weights,dim=0,keepdim=True)

File “/usr/local/lib64/python3.6/site-packages/horizon_plugin_pytorch/utils/model_helper.py”, line 52, in _call_impl

result = func(mod, *input, **kwargs)

File “/usr/local/lib64/python3.6/site-packages/horizon_plugin_pytorch/nn/quantized/functional_modules.py”, line 411, in sum

self.out_dtype,

File “/usr/local/lib64/python3.6/site-packages/horizon_plugin_pytorch/march.py”, line 71, in wrapped_func

return func(*args, **kwargs, march=get_march())

File “/usr/local/lib64/python3.6/site-packages/horizon_plugin_pytorch/utils/script_quantized_fn.py”, line 194, in wrapper

return fn(*args, **kwargs)

File “/usr/local/lib64/python3.6/site-packages/horizon_plugin_pytorch/nn/quantized/functional.py”, line 1198, in sum

march,

File “/usr/local/lib64/python3.6/site-packages/horizon_plugin_pytorch/nn/quantized/functional_impl.py”, line 1578, in _sum

r, x_scale, x_zero_point, “qint32”, scale, zero_point, dtype, march

File “/usr/local/lib64/python3.6/site-packages/horizon_plugin_pytorch/nn/quantized/functional_impl.py”, line 89, in _requantize

march,

IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)

请问您是对self.quant 配置了 per channel 量化吗?

我只做了这些设置。

请问没有对self.quant和self.edge_weight做额外配置吗?

这个配置是对整个模型设置int16吗?

在初始化中做了如下设置:

请问197行sum输入的shape是多少呀?