Prepare in Detail

Definition of Prepare

Prepare is the process of converting a floating-point model into a pseudo-quantized model. This process involves several key steps:

  1. Operator Replacement: Some torch functional operators (such as F.interpolate) need FakeQuantize nodes inserted during quantization. These operators are therefore replaced with the corresponding Module implementations (such as horizon_plugin_pytorch.nn.Interpolate), so that the FakeQuantize nodes can be placed inside the Module. The model before and after replacement is computationally equivalent.

  2. Operator Fusion: The BPU supports fusing specific computational patterns, where the intermediate results of the fused operators are kept in high precision. Multiple operators to be fused are therefore replaced with a single Module to prevent their intermediate results from being quantized. The model before and after fusion is also computationally equivalent.

  3. Operator Conversion: Floating-point operators are replaced with QAT (Quantization Aware Training) operators. According to the configured qconfig, QAT operators insert FakeQuantize nodes on their inputs, outputs, and weights.

Attention

To keep the converted QAT operators working as expected, make sure no further modifications are made to the model after calling prepare. For example, converting an unfused BN to a sync BN after prepare may cause the QAT BN to be modified again. The conversion to sync BN should be done before calling prepare.
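A minimal sketch of the required ordering, using the standard PyTorch `torch.nn.SyncBatchNorm.convert_sync_batchnorm` API (the model and the `prepare` arguments here are placeholders, not a real training setup):

```python
import torch
from torch import nn

# Placeholder float model with an unfused BN.
float_model = nn.Sequential(nn.Conv2d(3, 3, 1), nn.BatchNorm2d(3))

# 1. Convert BN to sync BN FIRST (standard PyTorch API).
float_model = nn.SyncBatchNorm.convert_sync_batchnorm(float_model)
assert isinstance(float_model[1], nn.SyncBatchNorm)

# 2. Only now hand the model to prepare:
# qat_model = prepare(float_model, example_inputs=example_inputs)
```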

  4. Model Structure Check: The QAT model is checked, and a check result file is generated.
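The equivalence claimed for operator fusion above can be illustrated by folding a BatchNorm into the preceding Conv by hand. This is a generic sketch of the well-known BN-folding identity, not the plugin's internal implementation:

```python
import torch
from torch import nn

torch.manual_seed(0)

conv = nn.Conv2d(3, 3, 1)
bn = nn.BatchNorm2d(3)
# Give BN non-trivial statistics so the equivalence check is meaningful.
bn.running_mean = torch.rand(3)
bn.running_var = torch.rand(3) + 0.5
with torch.no_grad():
    bn.weight.copy_(torch.rand(3) + 0.5)
    bn.bias.copy_(torch.rand(3))
conv.eval()
bn.eval()

# Fold BN into Conv: w' = w * gamma / sqrt(var + eps)
#                    b' = (b - mean) * gamma / sqrt(var + eps) + beta
std = torch.sqrt(bn.running_var + bn.eps)
fused = nn.Conv2d(3, 3, 1)
with torch.no_grad():
    fused.weight.copy_(conv.weight * (bn.weight / std).reshape(-1, 1, 1, 1))
    fused.bias.copy_((conv.bias - bn.running_mean) * bn.weight / std + bn.bias)
fused.eval()

x = torch.rand(1, 3, 4, 4)
# The fused Conv reproduces Conv->BN, so the Conv output never needs
# to be quantized separately.
assert torch.allclose(bn(conv(x)), fused(x), atol=1e-5)
```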
Note

Accuracy Description

The model after prepare differs from the floating-point model in terms of computational logic in the following aspects:

Pseudo-quantization (FakeQuantize) nodes are added to the model. In addition, for a very small number of operators (such as reciprocal) whose outputs may take extremely large values, the outputs are clipped to a reasonable range by default to adapt to quantization.

The above operations will cause changes in the numerical values of the model.
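The clipping behavior can be illustrated with a hypothetical sketch. The clip bound below is made up for illustration; the plugin chooses its own "reasonable range" internally:

```python
import torch

# A tiny input produces a huge reciprocal output. Without clipping, this
# outlier would force an enormous quantization scale, destroying the
# resolution available for ordinary values.
x = torch.tensor([1e-6, 0.5, 2.0])
y = x.reciprocal()  # approximately [1e6, 2.0, 0.5]

# Hypothetical bound, for illustration only.
y_clipped = y.clamp(min=-1e4, max=1e4)

assert y_clipped[0].item() == 1e4       # outlier clipped
assert abs(y_clipped[1].item() - 2.0) < 1e-6  # normal values untouched
```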

The usage of the prepare interface is as follows:

```python
from horizon_plugin_pytorch.quantization.prepare import prepare, PrepareMethod
from horizon_plugin_pytorch.quantization.qconfig_template import (
    default_qat_qconfig_setter,
    sensitive_op_qat_8bit_weight_16bit_fixed_act_qconfig_setter,
)

# When using templates, example_inputs and qconfig_setter must be provided.
# When method is PrepareMethod.JIT_STRIP or PrepareMethod.JIT, example_inputs must be provided.
# def prepare(
#     model: torch.nn.Module,
#     example_inputs: Any = None,  # used to get the model's graph structure; make sure it can be used to run forward
#     qconfig_setter: Optional[Union[Tuple[QconfigSetterBase, ...], QconfigSetterBase]] = None,  # qconfig template(s); multiple templates are supported, priority from high to low
#     method: PrepareMethod = PrepareMethod.JIT_STRIP,  # prepare method
# ) -> torch.nn.Module:

qat_model = prepare(
    float_model,
    example_inputs=example_inputs,
    qconfig_setter=(
        sensitive_op_qat_8bit_weight_16bit_fixed_act_qconfig_setter(table, ratio=0.2),
        default_qat_qconfig_setter,
    ),
    method=PrepareMethod.JIT,
)
```

PrepareMethod

The prepare methods include JIT_STRIP and EAGER. JIT_STRIP belongs to Graph Mode, while EAGER belongs to Eager Mode. They compare as follows:

| method | Principle | Advantages | Disadvantages |
| --- | --- | --- | --- |
| PrepareMethod.JIT_STRIP (Graph Mode) | Uses hooks and a Tensor subclass to capture the graph structure, performing operator replacement/fusion on the original forward. | Fully automatic, minimal code modification, hides many detail issues, easy to debug. | Dynamic code blocks need special handling. |
| PrepareMethod.EAGER (Eager Mode) | Does not perceive the graph structure; operator replacement/fusion must be done manually. | Flexible usage, controllable process, easy to debug and to handle various special needs. | Requires more manual operations and code modification; high learning cost. |

Currently, JIT_STRIP is the recommended method. Based on the positions of QuantStub and DequantStub in the model, JIT_STRIP identifies and skips pre-processing and post-processing.

Example

```python
import copy

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torch.quantization import DeQuantStub, QuantStub

from horizon_plugin_pytorch import March, set_march
from horizon_plugin_pytorch.fx.jit_scheme import Tracer
from horizon_plugin_pytorch.quantization import (
    FakeQuantState,
    get_qconfig,
    PrepareMethod,
    prepare,
    set_fake_quantize,
)


class Net(torch.nn.Module):
    def __init__(self, input_size, class_num) -> None:
        super().__init__()
        self.quant0 = QuantStub()
        self.quant1 = QuantStub()
        self.dequant = DeQuantStub()
        self.conv = nn.Conv2d(3, 3, 1)
        self.bn = nn.BatchNorm2d(3)
        self.classifier = nn.Conv2d(3, class_num, input_size)
        self.loss = nn.CrossEntropyLoss()

    def forward(self, input, other, target=None):
        # Preprocessing that does not need quantization. Use JIT_STRIP to
        # exclude these operations from the computational graph.
        input = (input - 128) / 128.0

        x = self.quant0(input)
        y = self.quant1(other)

        n = np.random.randint(1, 5)
        m = np.random.randint(1, 5)
        # Since the python code is not regenerated, this dynamic loop is
        # retained in the QAT model.
        for _ in range(n):
            for _ in range(m):
                # Dynamic code blocks involving operator replacement or
                # fusion must be marked. The marked part is the logic that
                # requires operator replacement or fusion, not the for loop.
                with Tracer.dynamic_block(self, "ConvBnAdd"):
                    x = self.conv(x)
                    x = self.bn(x)
                    x = x + y

        x = self.classifier(x).squeeze()

        # Since the python code is not regenerated, this dynamic control
        # flow is retained in the QAT model.
        if self.training:
            assert target is not None
            x = self.dequant(x)
            return F.cross_entropy(torch.softmax(x, dim=1), target)
        else:
            return torch.argmax(x, dim=1)


model = Net(6, 2)
train_example_input = (
    torch.rand(2, 3, 6, 6) * 256,
    torch.rand(2, 3, 6, 6),
    torch.tensor([[0.0, 1.0], [1.0, 0.0]]),
)
eval_example_input = train_example_input[:2]

model.eval()
set_march("nash-e")
model.qconfig = get_qconfig()
qat_model = prepare(
    model,
    example_inputs=copy.deepcopy(eval_example_input),
    method=PrepareMethod.JIT_STRIP,
)

qat_model.graph.print_tabular()
# opcode         name              target                                                      args                              kwargs
# -------------  ----------------  ----------------------------------------------------------  --------------------------------  ----------
# placeholder    input_0           input_0                                                     ()                                {}
# call_module    quant0            quant0                                                      (input_0,)                        {}
# placeholder    input_1           input_1                                                     ()                                {}
# call_module    quant1            quant1                                                      (input_1,)                        {}
# call_module    conv              conv                                                        (quant0,)                         {}
# call_module    bn                bn                                                          (conv,)                           {}
# get_attr       _generated_add_0  _generated_add_0                                            ()                                {}
# call_method    add_2             add                                                         (_generated_add_0, bn, quant1)    {}
# scope_end is automatically inserted during the trace process to mark the
# boundaries of sub-modules or dynamic code blocks; it does not correspond
# to any calculation.
# call_function  scope_end         <function Tracer.scope_end at 0x7f65d90e5e50>               ('_dynamic_block_ConvBnAdd',)     {}
# call_module    conv_1            conv                                                        (add_2,)                          {}
# call_module    bn_1              bn                                                          (conv_1,)                         {}
# get_attr       _generated_add_1  _generated_add_0                                            ()                                {}
# call_method    add_3             add                                                         (_generated_add_1, bn_1, quant1)  {}
# call_function  scope_end_1       <function Tracer.scope_end at 0x7f65d90e5e50>               ('_dynamic_block_ConvBnAdd',)     {}
# call_module    classifier        classifier                                                  (add_3,)                          {}
# call_function  squeeze           <method 'squeeze' of 'torch._C._TensorBase' objects>        (classifier,)                     {}
# call_function  argmax            <built-in method argmax of type object at 0x7f66f04cf820>   (squeeze,)                        {'dim': 1}
# call_function  scope_end_2       <function Tracer.scope_end at 0x7f65d90e5e50>               ('',)                             {}
# output         output            output                                                      ((argmax,),)                      {}

print(qat_model)
# GraphModuleImpl(
#   (quant0): QuantStub(
#     (activation_post_process): FakeQuantize(
#       fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8), quant_min=-128, quant_max=127, dtype=qint8, qscheme=torch.per_tensor_symmetric, ch_axis=-1, scale=tensor([1.]), zero_point=tensor([0])
#       (activation_post_process): MinMaxObserver(min_val=tensor([]), max_val=tensor([]))
#     )
#   )
#   (quant1): QuantStub(
#     (activation_post_process): FakeQuantize(
#       fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8), quant_min=-128, quant_max=127, dtype=qint8, qscheme=torch.per_tensor_symmetric, ch_axis=-1, scale=tensor([1.]), zero_point=tensor([0])
#       (activation_post_process): MinMaxObserver(min_val=tensor([]), max_val=tensor([]))
#     )
#   )
#   (dequant): DeQuantStub()
#   (conv): Identity()  # Since the forward code remains unchanged, conv and bn are still executed, so after fusion Conv and Bn must be replaced with Identity
#   (bn): Identity()
#   (classifier): Conv2d(
#     3, 2, kernel_size=(6, 6), stride=(1, 1)
#     (activation_post_process): FakeQuantize(
#       fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8), quant_min=-128, quant_max=127, dtype=qint8, qscheme=torch.per_tensor_symmetric, ch_axis=-1, scale=tensor([1.]), zero_point=tensor([0])
#       (activation_post_process): MinMaxObserver(min_val=tensor([]), max_val=tensor([]))
#     )
#     (weight_fake_quant): FakeQuantize(
#       fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8), quant_min=-128, quant_max=127, dtype=qint8, qscheme=torch.per_channel_symmetric, ch_axis=0, scale=tensor([1., 1.]), zero_point=tensor([0, 0])
#       (activation_post_process): MinMaxObserver(min_val=tensor([]), max_val=tensor([]))
#     )
#   )
#   (loss): CrossEntropyLoss()
#   (_generated_add_0): ConvAdd2d(  # '+' is automatically replaced with its Module form, and Conv and Bn are fused into it
#     3, 3, kernel_size=(1, 1), stride=(1, 1)
#     (activation_post_process): FakeQuantize(
#       fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8), quant_min=-128, quant_max=127, dtype=qint8, qscheme=torch.per_tensor_symmetric, ch_axis=-1, scale=tensor([1.]), zero_point=tensor([0])
#       (activation_post_process): MinMaxObserver(min_val=tensor([]), max_val=tensor([]))
#     )
#     (weight_fake_quant): FakeQuantize(
#       fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8), quant_min=-128, quant_max=127, dtype=qint8, qscheme=torch.per_channel_symmetric, ch_axis=0, scale=tensor([1., 1., 1.]), zero_point=tensor([0, 0, 0])
#       (activation_post_process): MinMaxObserver(min_val=tensor([]), max_val=tensor([]))
#     )
#   )
# )

qat_model.train()
set_fake_quantize(qat_model, FakeQuantState.QAT)
for _ in range(3):
    ret = qat_model(*train_example_input)
    ret.backward()
```
Attention
  1. Dynamic code blocks that involve operator replacement or fusion must be marked with Tracer.dynamic_block. Otherwise, quantization information will be confused or the forward will fail.
  2. If a part of the model whose call count varies (a sub-module or dynamic block) is executed only once during the trace, it may be fused with non-dynamic parts, leading to forward errors.

When using the graph-based prepare method, it is recommended that the model contain only deployment logic.

```python
# The execution path is controlled by `training` or other flags, which may
# not match the final deployment state. This can easily lead to misplaced
# quant and dequant nodes.
def forward(self, input, gt):
    conv_out = self.conv(input)
    if self.training:
        return self.loss(conv_out, gt)
    else:
        return self.sigmoid(conv_out)


# Extract the deployment logic and apply prepare only to this part.
def forward_infer(self, input, gt):
    conv_out = self.conv(input)
    return self.sigmoid(conv_out), conv_out


# Non-deployment logic is handled externally and excluded from prepare.
def forward(self, input, gt):
    sig_out, conv_out = self.forward_infer(input, gt)
    if self.training:
        return self.loss(conv_out, gt)
    else:
        return sig_out
```

If it’s not feasible to extract clean deployment logic due to readability or maintainability concerns, then the following checks should be performed:

  1. Check for missing or unexpected keys when loading the checkpoint. Missing or unexpected quantization parameters may indicate that some deployment logic was not quantized or some non-deployment logic was incorrectly quantized. Missing or unexpected model parameters may indicate a mismatch between the forward logic used during prepare and the one used when the checkpoint was generated.

  2. Check whether the FX graph generated by multiple prepare runs is consistent. Inconsistent FX graphs suggest that the forward logic differs across runs, and it should be verified whether this is intended.
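Check 1 above can be done by loading the checkpoint with `strict=False` and inspecting the returned key lists, a standard PyTorch API. A toy sketch with a deliberate structural mismatch (the two models stand in for the prepared model and the checkpoint source):

```python
import torch
from torch import nn

# model_a plays the role of the model that produced the checkpoint;
# model_b plays the role of the model built at load time, with an
# extra BN that the checkpoint knows nothing about.
model_a = nn.Sequential(nn.Conv2d(3, 3, 1))
model_b = nn.Sequential(nn.Conv2d(3, 3, 1), nn.BatchNorm2d(3))

result = model_b.load_state_dict(model_a.state_dict(), strict=False)

# Keys present in model_b but absent from the checkpoint: the BN's
# parameters and buffers. In a real QAT workflow, missing or unexpected
# FakeQuantize keys are the signal to investigate.
print(result.missing_keys)
assert "1.weight" in result.missing_keys
assert result.unexpected_keys == []
```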

Model Check

When example_inputs is provided, prepare performs a model structure check by default. If the check completes, a model_check_result.txt file is generated in the working directory. If the check fails, modify the model according to the warning messages, or call horizon_plugin_pytorch.utils.check_model.check_qat_model separately to check the model. The check process is the same as check_qat_model in the debug tool; see the check_qat_model documentation for how to interpret the result file.