prepare 详解

prepare 的定义

prepare 是将浮点模型转换为伪量化模型的过程。这个过程会做以下几件事情:

  1. 算子替换:部分 torch function 类型的算子(例如 F.interpolate)在量化时需要插入伪量化节点,因此需要将算子替换为对应的 Module 类型实现(horizon_plugin_pytorch.nn.Interpolate),以将伪量化节点放在此 Module 内部。替换前后的模型是等价的。

  2. 算子融合:BPU 支持将特定的计算 pattern 进行融合,融合后算子中间结果用高精度表示,因此我们将被融合的多个算子替换为一个 Module,以阻止中间结果的量化。融合前后的模型也是等价的。

  3. 算子转换:将浮点算子替换为 qat 算子。按照设置的 qconfig,qat 算子会在输入/输出/权重处添加伪量化/伪转换节点。

注意

请确保 prepare 之后不会再修改模型,否则已经被替换的 qat 算子可能产生不符合预期的行为。例如:prepare 之后再将未融合的 bn 转为 sync bn 可能导致 qat bn 被再次修改为 sync bn,应该在prepare之前将它转为 sync bn。

  1. 模型结构检查:检查 qat 模型,生成检查结果文件。
注解

精度说明

prepare 后的模型在计算逻辑上和浮点模型存在以下区别:

  1. 模型中加入了伪量化节点。

  2. 极少数输出存在极大值的算子(如 reciprocal),为了适配量化,会默认将输出 clip 至合理范围内。

以上操作会导致模型数值发生变化。

prepare 接口的用法如下:

from horizon_plugin_pytorch.quantization.prepare import prepare, PrepareMethod from horizon_plugin_pytorch.quantization.qconfig_template import ( default_qat_qconfig_setter, sensitive_op_qat_8bit_weight_16bit_fixed_act_qconfig_setter, ) # 使用模板时必须提供 example_inputs 和 qconfig_setter。 # method 为 PrepareMethod.JIT_STRIP 或 PrepareMethod.JIT 时,必须提供 example_inputs。 # def prepare( # model: torch.nn.Module, # example_inputs: Any = None, # 用来感知图结构,确保可以用来跑通 forward。 # qconfig_setter: Optional[Union[Tuple[QconfigSetterBase, ...], QconfigSetterBase]] = None, # qconfig 模板,支持传入多个模板,优先级从高到低。 # method: PrepareMethod = PrepareMethod.JIT_STRIP, # prepare 模式 # ) -> torch.nn.Module: qat_model = prepare( float_model, example_inputs=example_inputs, qconfig_setter=( sensitive_op_qat_8bit_weight_16bit_fixed_act_qconfig_setter(table, ratio=0.2), default_qat_qconfig_setter, ), method=PrepareMethod.JIT, )

PrepareMethod

prepare method包括JIT_STRIP和EAGER,JIT_STRIP属于Graph Mode,EAGER则属于PrepareMethod.EAGER,他们的对比如下:

method原理优点缺点
PrepareMethod.JIT_STRIP(Graph Mode)使用 hook 和 subclass tensor 的方式感知图结构,在原有 forward 上做算子替换/算子融合等操作。全自动,代码修改少,屏蔽了很多细节问题,便于 debug。动态代码块需要特殊处理。
PrepareMethod.EAGER(Eager Mode)不感知图结构,算子替换/算子融合需手动进行。用法灵活,过程可控,便于 debug 和处理各类特殊需求。手动操作较多,代码修改多,上手成本高。

目前,JIT_STRIP 为我们推荐的 method,JIT_STRIP 会根据模型中 QuantStub 和 DequantStub 的位置识别并跳过前后处理。

使用示例

import copy import numpy as np import torch from torch import nn from torch.nn import functional as F from torch.quantization import DeQuantStub, QuantStub from horizon_plugin_pytorch import March, set_march from horizon_plugin_pytorch.fx.jit_scheme import Tracer from horizon_plugin_pytorch.quantization import ( FakeQuantState, get_qconfig, PrepareMethod, prepare, set_fake_quantize, ) class Net(torch.nn.Module): def __init__(self, input_size, class_num) -> None: super().__init__() self.quant0 = QuantStub() self.quant1 = QuantStub() self.dequant = DeQuantStub() self.conv = nn.Conv2d(3, 3, 1) self.bn = nn.BatchNorm2d(3) self.classifier = nn.Conv2d(3, class_num, input_size) self.loss = nn.CrossEntropyLoss() def forward(self, input, other, target=None): # 不需要量化的前处理,使用 JIT_STRIP 时,将这些操作从计算图中剔除。 input = (input - 128) / 128.0 x = self.quant0(input) y = self.quant1(other) n = np.random.randint(1, 5) m = np.random.randint(1, 5) # 由于不重新生成 python code,此动态循环在 QAT 模型中保留。 for _ in range(n): for _ in range(m): # 动态循环中的代码块涉及到算子替换或算子融合时,必须进行标注。 # 标注的是需要算子替换或算子融合的逻辑,而不是 for 循环。 with Tracer.dynamic_block(self, "ConvBnAdd"): x = self.conv(x) x = self.bn(x) x = x + y x = self.classifier(x).squeeze() # 由于不重新生成 python code,此动态控制流在 QAT 模型中保留 if self.training: assert target is not None x = self.dequant(x) return F.cross_entropy(torch.softmax(x, dim=1), target) else: return torch.argmax(x, dim=1) model = Net(6, 2) train_example_input = ( torch.rand(2, 3, 6, 6) * 256, torch.rand(2, 3, 6, 6), torch.tensor([[0.0, 1.0], [1.0, 0.0]]), ) eval_example_input = train_example_input[:2] model.eval() set_march("nash-e") model.qconfig = get_qconfig() qat_model = prepare( model, example_inputs=copy.deepcopy(eval_example_input), method=PrepareMethod.JIT_STRIP, ) qat_model.graph.print_tabular() # opcode name target args kwargs # ------------- ---------------- --------------------------------------------------------- -------------------------------- ---------- # placeholder input_0 input_0 () {} # call_module quant0 quant0 (input_0,) {} # placeholder input_1 input_1 () {} # call_module quant1 quant1 (input_1,) {} # call_module conv conv (quant0,) {} # call_module bn bn (conv,) {} # get_attr _generated_add_0 _generated_add_0 () {} # call_method add_2 add (_generated_add_0, bn, quant1) {} # scope_end 是在 trace 过程中自动插入的,用于标记子 module 或动态代码块的边界,不对应实际计算 # call_function scope_end <function Tracer.scope_end at 0x7f65d90e5e50> ('_dynamic_block_ConvBnAdd',) {} # call_module conv_1 conv (add_2,) {} # call_module bn_1 bn (conv_1,) {} # get_attr _generated_add_1 _generated_add_0 () {} # call_method add_3 add (_generated_add_1, bn_1, quant1) {} # call_function scope_end_1 <function Tracer.scope_end at 0x7f65d90e5e50> ('_dynamic_block_ConvBnAdd',) {} # call_module classifier classifier (add_3,) {} # call_function squeeze <method 'squeeze' of 'torch._C._TensorBase' objects> (classifier,) {} # call_function argmax <built-in method argmax of type object at 0x7f66f04cf820> (squeeze,) {'dim': 1} # call_function scope_end_2 <function Tracer.scope_end at 0x7f65d90e5e50> ('',) {} # output output output ((argmax,),) {} print(qat_model) # GraphModuleImpl( # (quant0): QuantStub( # (activation_post_process): FakeQuantize( # fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8), quant_min=-128, quant_max=127, dtype=qint8, qscheme=torch.per_tensor_symmetric, ch_axis=-1, scale=tensor([1.]), zero_point=tensor([0]) # (activation_post_process): MinMaxObserver(min_val=tensor([]), max_val=tensor([])) # ) # ) # (quant1): QuantStub( # (activation_post_process): FakeQuantize( # fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8), quant_min=-128, quant_max=127, dtype=qint8, qscheme=torch.per_tensor_symmetric, ch_axis=-1, scale=tensor([1.]), zero_point=tensor([0]) # (activation_post_process): MinMaxObserver(min_val=tensor([]), max_val=tensor([])) # ) # ) # (dequant): DeQuantStub() # (conv): Identity() # 由于 forward 代码不变,conv 和 bn 仍将被执行,所以融合后必须将 Conv 和 Bn 替换为 Identity # (bn): Identity() # (classifier): Conv2d( # 3, 2, kernel_size=(6, 6), stride=(1, 1) # (activation_post_process): FakeQuantize( # fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8), quant_min=-128, quant_max=127, dtype=qint8, qscheme=torch.per_tensor_symmetric, ch_axis=-1, scale=tensor([1.]), zero_point=tensor([0]) # (activation_post_process): MinMaxObserver(min_val=tensor([]), max_val=tensor([])) # ) # (weight_fake_quant): FakeQuantize( # fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8), quant_min=-128, quant_max=127, dtype=qint8, qscheme=torch.per_channel_symmetric, ch_axis=0, scale=tensor([1., 1.]), zero_point=tensor([0, 0]) # (activation_post_process): MinMaxObserver(min_val=tensor([]), max_val=tensor([])) # ) # ) # (loss): CrossEntropyLoss() # (_generated_add_0): ConvAdd2d( # 自动将 '+' 替换为 Module 形式,并将 Conv 和 Bn 融合进来 # 3, 3, kernel_size=(1, 1), stride=(1, 1) # (activation_post_process): FakeQuantize( # fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8), quant_min=-128, quant_max=127, dtype=qint8, qscheme=torch.per_tensor_symmetric, ch_axis=-1, scale=tensor([1.]), zero_point=tensor([0]) # (activation_post_process): MinMaxObserver(min_val=tensor([]), max_val=tensor([])) # ) # (weight_fake_quant): FakeQuantize( # fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8), quant_min=-128, quant_max=127, dtype=qint8, qscheme=torch.per_channel_symmetric, ch_axis=0, scale=tensor([1., 1., 1.]), zero_point=tensor([0, 0, 0]) # (activation_post_process): MinMaxObserver(min_val=tensor([]), max_val=tensor([])) # ) # ) # ) qat_model.train() set_fake_quantize(qat_model, FakeQuantState.QAT) for _ in range(3): ret = qat_model(*train_example_input) ret.backward()
注意
  1. 动态代码块涉及到算子替换或算子融合时,必须使用 Tracer.dynamic_block 进行标注,否则将导致量化信息错乱或 forward 报错。
  2. 模型中调用次数变化的部分(子 module 或 dynamic_block),若在 trace 时仅执行了一次,则有可能和非动态部分产生算子融合,导致 forward 报错。

使用基于图的 prepare 方法时,最好保证 model 只包含部署逻辑。

# 执行分支受 training 或其他开关控制,和最终的部署状态可能不一致,quant 和 dequant 位置容易加错。 def forward(self, input, gt): conv_out = self.conv(input) if self.training: return self.loss(conv_out, gt) else: return self.sigmoid(conv_out) # 剥离出部署逻辑。只针对这部分逻辑 prepare。 def forward_infer(self, input, gt): conv_out = self.conv(input) return self.sigmoid(conv_out), conv_out # 非部署逻辑放在外面不参与 prepare。 def forward(self, input, gt): sig_out, conv_out = self.forward_infer(input, gt) if self.training: return self.loss(conv_out, gt) else: return sig_out

如果代码因可读性 / 可维护性等原因无法剥离出干净的部署逻辑,那么需要做以下检查:

  1. 加载 ckpt 时是否存在 missing key 和 unexpected key。缺少或多出量化参数可能意味着某些部署逻辑没有被量化或某些非部署逻辑被量化了,缺少或多出模型参数可能意味着此次 prepare 的 forward 逻辑和产生 ckpt 的 forward 逻辑未对齐。

  2. 多次 prepare 产生的 fx graph 是否一致。fx graph 不一致意味着多次 prepare 的 forward 逻辑不一致,需要检查是否符合预期。

模型检查

在提供 example_inputs 的情况下,prepare 默认会对模型结构进行检查。如果检查完成,可以在运行目录下找到 model_check_result.txt 文件,如果检查失败,则需要根据警告提示修改模型或单独调用 horizon_plugin_pytorch.utils.check_model.check_qat_model 检查模型。检查流程和 debug 工具中的 check_qat_model 一致,结果文件的分析详见 check_qat_model 相关文档。