Prepare is the process of converting a floating-point model into a pseudo-quantized model. This process involves several key steps:
Operator Replacement: Some torch functional operators (such as F.interpolate) need FakeQuantize nodes inserted during quantization. These operators are therefore replaced with corresponding Module implementations (e.g., horizon_plugin_pytorch.nn.Interpolate) so that the FakeQuantize nodes can be placed inside the Module. The model before and after replacement is equivalent.
Operator Fusion: BPU supports fusing specific computational patterns, where the intermediate results of fused operators are represented with high precision. Therefore, we replace multiple operators to be fused with a single Module to prevent quantizing the intermediate results. The model before and after fusion is also equivalent.
Operator Conversion: Floating-point operators are replaced with QAT (Quantization Aware Training) operators. According to the configured qconfig, QAT operators add FakeQuantize nodes on their inputs, outputs, and weights.
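The fusion step can be illustrated with torch's own fuse_modules, which merges a Conv-BN-ReLU pattern into a single module so the intermediate results never become separate tensors to fake-quantize. This is a sketch with the stock torch API; horizon_plugin_pytorch performs the equivalent replacement with its own fused modules.

```python
import torch
from torch import nn
from torch.ao.quantization import fuse_modules

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3, padding=1)
        self.bn = nn.BatchNorm2d(8)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))

m = Block().eval()  # conv-bn fusion requires eval mode in stock torch
fused = fuse_modules(m, [["conv", "bn", "relu"]])
# After fusion, bn and relu are folded into conv; the intermediate conv
# and bn outputs no longer exist, so nothing quantizes them separately.
print(type(fused.conv).__name__)  # ConvReLU2d
print(type(fused.bn).__name__)    # Identity
```

The fused model is numerically equivalent to the original, which is exactly the property the prepare step relies on.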
To keep the converted QAT operators working as expected, ensure that no further modifications are made to the model after calling prepare. For example, converting an unfused BN to SyncBN after prepare may cause the already-converted QAT BN to be modified again; the conversion to SyncBN must therefore be done before calling prepare.
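The required ordering can be sketched with the stock torch SyncBN conversion API (the prepare call itself belongs to horizon_plugin_pytorch and is only indicated in a comment here):

```python
import torch
from torch import nn

model = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1),
    nn.BatchNorm2d(8),
    nn.ReLU(),
)

# Correct order: convert BN to SyncBN first...
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
print(type(model[1]).__name__)  # SyncBatchNorm

# ...and only then call horizon_plugin_pytorch's prepare on the converted
# model. Converting to SyncBN *after* prepare would modify the
# already-converted QAT BN and break it.
```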
Accuracy Description
The model after prepare differs from the floating-point model in terms of computational logic in the following aspects:
Pseudo-quantization (FakeQuantize) nodes are added to the model. In addition, for a very small number of operators (such as reciprocal) whose outputs may contain extremely large values, the outputs are clipped to a reasonable range by default to make them quantization-friendly.
The above operations will cause changes in the numerical values of the model.
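The numerical effect of a pseudo-quantization node can be sketched in plain Python: values are scaled, rounded onto an integer grid, clamped to the representable range, and mapped back. This is a toy int8 sketch, not the library's implementation.

```python
def fake_quantize(x, scale, qmin=-128, qmax=127):
    """Quantize-dequantize a value on an int8 grid (toy sketch)."""
    q = round(x / scale)
    q = max(qmin, min(qmax, q))   # clamp to the representable range
    return q * scale

scale = 0.1
print(fake_quantize(0.123, scale))   # rounds onto the 0.1 grid -> 0.1
print(fake_quantize(1000.0, scale))  # clamped to qmax * scale (about 12.7),
                                     # analogous to clipping reciprocal's
                                     # huge outputs
```

Both the rounding and the clamping change the model's numerical values, which is why QAT accuracy differs from floating-point accuracy.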
The usage of the prepare interface is as follows:
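A minimal usage sketch is shown below. The stubs use torch's QuantStub/DeQuantStub as stand-ins for the plugin's own stub types, and the exact prepare signature in the comment is an assumption pieced together from this document's mentions of example_inputs and PrepareMethod; verify it against your plugin version's API reference.

```python
import torch
from torch import nn
from torch.ao.quantization import QuantStub, DeQuantStub

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = QuantStub()      # marks where quantization starts
        self.conv = nn.Conv2d(3, 8, 3, padding=1)
        self.dequant = DeQuantStub()  # marks where quantization ends

    def forward(self, x):
        return self.dequant(self.conv(self.quant(x)))

float_model = Net()
example_inputs = (torch.randn(1, 3, 32, 32),)

# Hypothetical call shape -- check the horizon_plugin_pytorch docs:
# from horizon_plugin_pytorch.quantization import prepare, PrepareMethod
# qat_model = prepare(float_model.train(), example_inputs=example_inputs,
#                     method=PrepareMethod.JIT_STRIP)
```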
prepare supports two methods: PrepareMethod.JIT_STRIP and PrepareMethod.EAGER. JIT_STRIP works in Graph Mode, while EAGER works in Eager Mode. They compare as follows:
| method | Principle | Advantages | Disadvantages |
|---|---|---|---|
| PrepareMethod.JIT_STRIP (Graph Mode) | Uses hooks and a Tensor subclass to capture the graph structure, then performs operator replacement/fusion on the original forward. | Fully automatic, minimal code modification, hides many low-level details, easy to debug. | Dynamic code blocks need special handling. |
| PrepareMethod.EAGER (Eager Mode) | Does not perceive the graph structure; operator replacement/fusion must be done manually. | Flexible usage, controllable process, easy to debug and handle various special needs. | Requires more manual operations and code modifications, high learning cost. |
JIT_STRIP is currently the recommended method. It identifies and skips pre-processing and post-processing based on the positions of QuantStub and DequantStub in the model.
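With stubs placed as below, the normalization before QuantStub and the argmax after DequantStub sit outside the quantized region and would be treated as pre/post-processing. Again, torch's stubs stand in for the plugin's own stub types in this sketch.

```python
import torch
from torch import nn
from torch.ao.quantization import QuantStub, DeQuantStub

class Seg(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = QuantStub()
        self.conv = nn.Conv2d(3, 4, 3, padding=1)
        self.dequant = DeQuantStub()

    def forward(self, img):
        x = (img - 127.5) / 128.0  # pre-process: before QuantStub -> skipped
        x = self.quant(x)
        x = self.conv(x)           # deployment logic: gets quantized
        x = self.dequant(x)
        return x.argmax(dim=1)     # post-process: after DequantStub -> skipped

model = Seg()
out = model(torch.randn(1, 3, 16, 16))
print(out.shape)  # torch.Size([1, 16, 16])
```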
When using the graph-based prepare method, it's recommended that the model only contains deployment logic.
If it’s not feasible to extract clean deployment logic due to readability or maintainability concerns, then the following checks should be performed:
Check for missing or unexpected keys when loading the checkpoint. Missing or unexpected quantization parameters may indicate that some deployment logic was not quantized or some non-deployment logic was incorrectly quantized. Missing or unexpected model parameters may indicate a mismatch between the forward logic used during prepare and the one used when the checkpoint was generated.
Check whether the FX graph generated by multiple prepare runs is consistent. Inconsistent FX graphs suggest that the forward logic differs across runs, and it should be verified whether this is intended.
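The checkpoint-key check can be done with load_state_dict(strict=False), which returns the missing and unexpected keys instead of raising (a plain torch sketch with toy models):

```python
import torch
from torch import nn

train_model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
ckpt = train_model.state_dict()

# Deploy-time model whose forward logic drifted from the checkpoint's:
deploy_model = nn.Sequential(nn.Conv2d(3, 8, 3))

result = deploy_model.load_state_dict(ckpt, strict=False)
print(result.missing_keys)     # []
print(result.unexpected_keys)  # ['1.weight', '1.bias', '1.running_mean', ...]
# Non-empty lists like these signal a mismatch between the prepared model
# and the checkpoint and should be investigated before continuing.
```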
When example_inputs is provided, prepare performs a model structure check by default. If the check completes, a model_check_result.txt file can be found in the working directory. If the check fails, modify the model according to the warning messages, or call horizon_plugin_pytorch.utils.check_model.check_qat_model separately to check the model. The check process is the same as check_qat_model in the debug tool; the analysis of the result file is detailed in the check_qat_model documentation.