训练部署一致性
在部署模型时,可能出现 QAT 模型精度和 HBM 模型精度不一致的问题,此章节主要介绍在 QAT 策略上如何避免这类问题,以及发生这类问题时如何定位并解决问题。
注意
训练部署无法做到比特一致,必然存在极少数 case 无法完全对齐,应当以数据集的评测精度作为判断是否存在一致性问题的依据。
训练部署一致性问题分为两类:
-
用户侧问题。例如:前后处理不一致,模型版本不一致等。
-
工具侧问题。模型在 export / convert / compile 过程中产生的不一致。
无论是哪种类型的问题,都需要您使用训练部署一致性工具排查,用户侧问题需要您自行解决,而工具侧问题需要提供 debug 产出物给地平线技术支持人员,由研发团队分析后给出解决方案。
在一致性问题的排查过程中,涉及到以下几种模型:
| 模型 | 说明 | 获取方法 |
|---|
| qat.pt | torch qat 模型。 | 对浮点模型使用 prepare 接口。 |
| qat.export.pt | torch qat export 模型。对 qat.pt 做了非等价替换,qat.export.pt 计算逻辑与 qat.bc 完全一致。 | 对 qat.pt 模型使用 pre_export 接口。 |
| qat.bc | 导出产生的 hbir 模型。 | 对 qat.pt 模型使用 export 接口。 |
| quantized.bc | 转换产生的 hbir 模型。 | 对 qat.bc 模型使用 convert 接口。 |
| hbm | 编译产生的部署模型。 | 对 quantized.bc 模型使用 compile 接口。 |
pre_export 接口用法如下:
from horizon_plugin_pytorch.quantization.hbdk4 import pre_export
qat_export_pt = pre_export(qat_pt)
高一致性 QAT 策略(实验功能)
高一致性策略封装在 horizon_plugin_pytorch.qat_mode.ConsistencyStrategy 下,可以使用 set_consistency_level 接口设置策略。
当前支持五个等级( 0 - 4 )的策略,等级越高,一致性越好,但 QAT 精度可能受到轻微影响。推荐直接使用 level 2,在绝大多数情况下对 QAT 精度无影响,甚至可以改善因截断误差引起的精度问题,对性能和一致性有正收益。
对于未使用高一致性策略的 QAT 模型,如果希望不重训获得一致性更高的定点模型,可以在 prepare 模型前设置一致性策略等级为 0(不重训的情况下只有 level 0 有效,level 1 - 4 需要设置等级后重训模型)。
from horizon_plugin_pytorch.qat_mode import ConsistencyStrategy
# 必须在 prepare 之前设置一致性策略
ConsistencyStrategy.set_consistency_level(2)
...
qat_pt = prepare(float_model)
...
qat_bc = export(qat_pt, example_inputs)
# 如果设置 ConsistencyStrategy.set_consistency_level(0), 可以做如下检查
print(qat_bc._high_precision_qpp) # 值应为 true
print(qat_bc._fuse_requantize) # 值应为 False
quantized_bc = convert(qat_bc, march)
注意
高一致性 QAT 策略需要 hbdk 版本不低于 4.4.2,plugin 版本不低于 2.7.1。
一致性问题定位流程

一致性问题定位流程如下:
-
构造大数据集。要求数据集至少包含 1000 帧且评测精度较为稳定,不存在 hbm 或 bc 模型精度大幅高于 torch qat 模型精度的情况,可以复现 hbm / quantize.bc 一致性问题。
-
适配 bc 推理流程。
-
用户侧问题排查。cpu 跑 qat.export.pt 和 qat.bc 的小数据集精度,验证 bc 推理流程是否正确。
3.1 精度一致,进入工具侧问题排查。
3.2 精度不一致,关闭伪量化再验证 qat.export.pt 和 qat.bc 的小数据集精度是否一致。
from horizon_plugin_pytorch.quantization.hbdk4 import export, pre_export
from horizon_plugin_pytorch.quantization import FakeQuantState
qat_pt.eval()
set_fake_quantize(qat_pt, FakeQuantState._FLOAT)
qat_bc = export(qat_pt, example_inputs)
qat_export_pt = pre_export(qat_pt)
3.2.1 精度一致,排查关闭伪量化前,伪量化和 observer 的状态是否与 validation 状态一致。
print(qat_pt)
# fake_quant_enabled 应为 True, observer_enabled 应为 False
GraphModuleImpl(
(quant): QuantStub(
(activation_post_process): FakeQuantize(
dtype=qint8, fake_quant_enabled=True, observer_enabled=False, qscheme=torch.per_tensor_symmetric, ch_axis=-1, scale=tensor([0.1250]), zero_point=tensor([0])
(activation_post_process): MinMaxObserver(min_val=-10.0,max_val=9.999998092651367,averaging_constant=1)
)
)
...
)
3.2.2 精度不一致,qat.export.pt 和 qat.bc 分别跑统计量,逐层对比.(需要您自行保证名字对应关系并使用相同输入)
from horizon_plugin_profiler import QuantAnalysis
qa = QuantAnalysis(qat_export_pt, qat_bc, "export")
# torch 与 bc 可接受同一格式输入时,一起跑统计量
qa.set_bad_case(badcase)
qa.run()
# torch 与 bc 不可接受同一格式输入时,分开跑统计量,pt_badcase 与 bc_badcase 除格式外全部相同。
qa.set_bad_case(pt_badcase)
qa.run(run_baseline_model=True, run_analysis_model=False)
qa.set_bad_case(bc_badcase)
qa.run(run_baseline_model=False, run_analysis_model=True)
# 逐层对比
qa.compare_per_layer()
- 工具侧问题排查。
4.1 quantized.bc 跑大数据集精度。
4.2 如果 quantized.bc 和 qat.pt 对比,精度有一致性问题,那么 qat.export.pt 跑大数据集精度。
4.2.1 如果 qat.export.pt 和 qat.pt 对比,精度有一致性问题,那么问题发生在 export 阶段。qat.pt 和 qat.export.pt 跑一致性敏感度和逐层对比。在常规方法无法定位问题时,用 pre export 接口分段定位。
from horizon_plugin_profiler import QuantAnalysis
from horizon_plugin_pytorch.quantization.hbdk4 import pre_export
# qat.pt 和 qat.export.pt 跑一致性敏感度和逐层对比
qa = QuantAnalysis(qat_pt, qat_export_pt, "pre_export")
qa.auto_find_bad_case(dataloader)
qa.run()
qa.compare_per_layer()
qa.sensitivity()
# pre export 接口分段
qat_pt.module_a = pre_export(qat_pt.module_a)
4.2.2 如果 qat.export.pt 和 qat.pt 对比,精度没有一致性问题,那么问题发生在 convert 阶段。qat.bc 和 quantized.bc 跑查找 badcase + 逐层对比,qat.export.pt 跑一致性敏感度(复用 qat.bc 和 quantized.bc 对比找出来的 badcase)。
from horizon_plugin_profiler import QuantAnalysis
# qat.bc 和 quantized.bc 跑查找 badcase + 逐层对比
qa = QuantAnalysis(qat_bc, quantized_bc, "convert")
qa.auto_find_bad_case(dataloader)
qa.run()
qa.compare_per_layer()
# qat.export.pt 跑一致性敏感度
qa = QuantAnalysis(qat_export_pt, quantized_bc, "convert")
qa.load_bad_case()
qa.sensitivity()
在常规方法无法定位问题时,使用 bc 编辑工具将 quantized.bc 分段转 cpu 来定位问题。bc 编辑工具路径为 horizon_plugin_profiler/bc_editor/bc_editor.py,同一目录下有配置示例。
# 查看 qat bc 模型文本
print(qat_bc.module.get_asm(enable_debug_info=True))
# “%” 后面的数字为 hbir 算子编号
module attributes {hbdk.legacy_round = true} {
func.func @bev_gkt_mixvargenet_multitask_nuscenes(%arg0: tensor<6x3x512x960...
%0 = "qnt.const_fake_quant"(%arg0) <{bits = 8 : i64, illegal = true, max...
%1 = "hbir.constant"() <{values = dense<"0xC27B5D3DFF6DE33C1822093DDA9642...
%2 = "qnt.const_fake_quant"(%1) <{axis = 0 : i64, bits = 8 : i64, illegal...
......
# config.json
{
"remove_fake_quant": [[1, 100], 102] # 删除 hbir 中编号 1~100 和 102 的伪量化
}
# 编辑后得到 qat_modified.bc,再 convert 可以得到部分算子退回 cpu 的 quantized.bc
python3 bc_editor.py --bc_path qat.bc --config_path config.json --new_bc_path qat_modified.bc
对比编辑前后 quantized.bc 和 qat.bc 的差异,可以看出哪些算子转 cpu 能明显提高一致性。