训练部署一致性

在部署模型时,可能出现 QAT 模型精度和 HBM 模型精度不一致的问题,此章节主要介绍在 QAT 策略上如何避免这类问题,以及发生这类问题时如何定位并解决问题。

注意

训练部署无法做到比特一致,必然存在极少数 case 无法完全对齐,应当以数据集的评测精度作为判断是否存在一致性问题的依据。

训练部署一致性问题分为两类:

  1. 用户侧问题。例如:前后处理不一致,模型版本不一致等。

  2. 工具侧问题。模型在 export / convert / compile 过程中产生的不一致。

无论是哪种类型的问题,都需要您使用训练部署一致性工具排查,用户侧问题需要您自行解决,而工具侧问题需要提供 debug 产出物给地平线技术支持人员,由研发团队分析后给出解决方案。

在一致性问题的排查过程中,涉及到以下几种模型:

模型说明获取方法
qat.pttorch qat 模型。对浮点模型使用 prepare 接口。
qat.export.pttorch qat export 模型。对 qat.pt 做了非等价替换,qat.export.pt 计算逻辑与 qat.bc 完全一致。qat.pt 模型使用 pre_export 接口。
qat.bc导出产生的 hbir 模型。qat.pt 模型使用 export 接口。
quantized.bc转换产生的 hbir 模型。qat.bc 模型使用 convert 接口。
hbm编译产生的部署模型。quantized.bc 模型使用 compile 接口。

pre_export 接口用法如下:

from horizon_plugin_pytorch.quantization.hbdk4 import pre_export qat_export_pt = pre_export(qat_pt)

高一致性 QAT 策略(实验功能)

高一致性策略封装在 horizon_plugin_pytorch.qat_mode.ConsistencyStrategy 下,可以使用 set_consistency_level 接口设置策略。

当前支持五个等级( 0 - 4 )的策略,等级越高,一致性越好,但 QAT 精度可能受到轻微影响。推荐直接使用 level 2,在绝大多数情况下对 QAT 精度无影响,甚至可以改善因截断误差引起的精度问题,对性能和一致性有正收益。

对于未使用高一致性策略的 QAT 模型,如果希望不重训获得一致性更高的定点模型,可以在 prepare 模型前设置一致性策略等级为 0(不重训的情况下只有 level 0 有效,level 1 - 4 需要设置等级后重训模型)。

from horizon_plugin_pytorch.qat_mode import ConsistencyStrategy # 必须在 prepare 之前设置一致性策略 ConsistencyStrategy.set_consistency_level(2) ... qat_pt = prepare(float_model) ... qat_bc = export(qat_pt, example_inputs) # 如果设置 ConsistencyStrategy.set_consistency_level(0), 可以做如下检查 print(qat_bc._high_precision_qpp) # 值应为 true print(qat_bc._fuse_requantize) # 值应为 False quantized_bc = convert(qat_bc, march)
注意

高一致性 QAT 策略需要 hbdk 版本不低于 4.4.2,plugin 版本不低于 2.7.1。

一致性问题定位流程

一致性问题定位流程如下:

  1. 构造大数据集。要求数据集至少包含 1000 帧且评测精度较为稳定,不存在 hbm 或 bc 模型精度大幅高于 torch qat 模型精度的情况,可以复现 hbm / quantize.bc 一致性问题。

  2. 适配 bc 推理流程。

  3. 用户侧问题排查。cpu 跑 qat.export.pt 和 qat.bc 的小数据集精度,验证 bc 推理流程是否正确。

3.1 精度一致,进入工具侧问题排查。

3.2 精度不一致,关闭伪量化再验证 qat.export.pt 和 qat.bc 的小数据集精度是否一致。

from horizon_plugin_pytorch.quantization.hbdk4 import export, pre_export from horizon_plugin_pytorch.quantization import FakeQuantState qat_pt.eval() set_fake_quantize(qat_pt, FakeQuantState._FLOAT) qat_bc = export(qat_pt, example_inputs) qat_export_pt = pre_export(qat_pt)

3.2.1 精度一致,排查关闭伪量化前,伪量化和 observer 的状态是否与 validation 状态一致。

print(qat_pt) # fake_quant_enabled 应为 True, observer_enabled 应为 False GraphModuleImpl( (quant): QuantStub( (activation_post_process): FakeQuantize( dtype=qint8, fake_quant_enabled=True, observer_enabled=False, qscheme=torch.per_tensor_symmetric, ch_axis=-1, scale=tensor([0.1250]), zero_point=tensor([0]) (activation_post_process): MinMaxObserver(min_val=-10.0,max_val=9.999998092651367,averaging_constant=1) ) ) ... )

3.2.2 精度不一致,qat.export.pt 和 qat.bc 分别跑统计量,逐层对比.(需要您自行保证名字对应关系并使用相同输入)

from horizon_plugin_profiler import QuantAnalysis qa = QuantAnalysis(qat_export_pt, qat_bc, "export") # torch 与 bc 可接受同一格式输入时,一起跑统计量 qa.set_bad_case(badcase) qa.run() # torch 与 bc 不可接受同一格式输入时,分开跑统计量,pt_badcase 与 bc_badcase 除格式外全部相同。 qa.set_bad_case(pt_badcase) qa.run(run_baseline_model=True, run_analysis_model=False) qa.set_bad_case(bc_badcase) qa.run(run_baseline_model=False, run_analysis_model=True) # 逐层对比 qa.compare_per_layer()
  1. 工具侧问题排查。

4.1 quantized.bc 跑大数据集精度。

4.2 如果 quantized.bc 和 qat.pt 对比,精度有一致性问题,那么 qat.export.pt 跑大数据集精度。

4.2.1 如果 qat.export.pt 和 qat.pt 对比,精度有一致性问题,那么问题发生在 export 阶段。qat.pt 和 qat.export.pt 跑一致性敏感度和逐层对比。在常规方法无法定位问题时,用 pre export 接口分段定位。

from horizon_plugin_profiler import QuantAnalysis from horizon_plugin_pytorch.quantization.hbdk4 import pre_export # qat.pt 和 qat.export.pt 跑一致性敏感度和逐层对比 qa = QuantAnalysis(qat_pt, qat_export_pt, "pre_export") qa.auto_find_bad_case(dataloader) qa.run() qa.compare_per_layer() qa.sensitivity() # pre export 接口分段 qat_pt.module_a = pre_export(qat_pt.module_a)

4.2.2 如果 qat.export.pt 和 qat.pt 对比,精度没有一致性问题,那么问题发生在 convert 阶段。qat.bc 和 quantized.bc 跑查找 badcase + 逐层对比,qat.export.pt 跑一致性敏感度(复用 qat.bc 和 quantized.bc 对比找出来的 badcase)。

from horizon_plugin_profiler import QuantAnalysis # qat.bc 和 quantized.bc 跑查找 badcase + 逐层对比 qa = QuantAnalysis(qat_bc, quantized_bc, "convert") qa.auto_find_bad_case(dataloader) qa.run() qa.compare_per_layer() # qat.export.pt 跑一致性敏感度 qa = QuantAnalysis(qat_export_pt, quantized_bc, "convert") qa.load_bad_case() qa.sensitivity()

在常规方法无法定位问题时,使用 bc 编辑工具将 quantized.bc 分段转 cpu 来定位问题。bc 编辑工具路径为 horizon_plugin_profiler/bc_editor/bc_editor.py,同一目录下有配置示例。

# 查看 qat bc 模型文本 print(qat_bc.module.get_asm(enable_debug_info=True)) # “%” 后面的数字为 hbir 算子编号 module attributes {hbdk.legacy_round = true} { func.func @bev_gkt_mixvargenet_multitask_nuscenes(%arg0: tensor<6x3x512x960... %0 = "qnt.const_fake_quant"(%arg0) <{bits = 8 : i64, illegal = true, max... %1 = "hbir.constant"() <{values = dense<"0xC27B5D3DFF6DE33C1822093DDA9642... %2 = "qnt.const_fake_quant"(%1) <{axis = 0 : i64, bits = 8 : i64, illegal... ...... # config.json { "remove_fake_quant": [[1, 100], 102] # 删除 hbir 中编号 1~100 和 102 的伪量化 } # 编辑后得到 qat_modified.bc,再 convert 可以得到部分算子退回 cpu 的 quantized.bc python3 bc_editor.py --bc_path qat.bc --config_path config.json --new_bc_path qat_modified.bc

对比编辑前后 quantized.bc 和 qat.bc 的差异,可以看出哪些算子转 cpu 能明显提高一致性。