Quantization Accuracy Tuning Practice
This chapter walks through the complete accuracy tuning process with a practical example. Please read the Quantization Accuracy Tuning Guide chapter first to understand the relevant theoretical background and tool usage.
Model Structure and Quantization Configuration Check
After completing the necessary adaptations for QAT, run the program. A model_check_result.txt file will be generated in the working directory. Check this file first.
Operator Fusion
Start by reviewing the operator fusion status and check whether any operators have not been fused as expected.
Fusable modules are listed below:
name type
------------------------------------------------ --------------------------------------------------------------------------
model.view_transformation.input_proj.0.0 <class 'horizon_plugin_pytorch.nn.qat.conv2d.Conv2d'>
model.view_transformation._generated_add_0 <class 'horizon_plugin_pytorch.nn.qat.functional_modules.FloatFunctional'>
name type
------------------------------------------------ --------------------------------------------------------------------------
model.view_transformation.input_proj.1.0 <class 'horizon_plugin_pytorch.nn.qat.conv2d.Conv2d'>
model.view_transformation._generated_add_2 <class 'horizon_plugin_pytorch.nn.qat.functional_modules.FloatFunctional'>
name type
------------------------------------------------ --------------------------------------------------------------------------
model.view_transformation.input_proj.2.0 <class 'horizon_plugin_pytorch.nn.qat.conv2d.Conv2d'>
model.view_transformation._generated_add_4 <class 'horizon_plugin_pytorch.nn.qat.functional_modules.FloatFunctional'>
name type
------------------------------------------------ --------------------------------------------------------------------------
model.view_transformation.input_proj.3.0 <class 'horizon_plugin_pytorch.nn.qat.conv2d.Conv2d'>
model.view_transformation._generated_add_6 <class 'horizon_plugin_pytorch.nn.qat.functional_modules.FloatFunctional'>
name type
------------------------------------------------ --------------------------------------------------------------------------
model.view_transformation.input_proj.4.0 <class 'horizon_plugin_pytorch.nn.qat.conv2d.Conv2d'>
model.view_transformation._generated_add_8 <class 'horizon_plugin_pytorch.nn.qat.functional_modules.FloatFunctional'>
name type
------------------------------------------------ --------------------------------------------------------------------------
model.view_transformation.input_proj.5.0 <class 'horizon_plugin_pytorch.nn.qat.conv2d.Conv2d'>
model.view_transformation._generated_add_10 <class 'horizon_plugin_pytorch.nn.qat.functional_modules.FloatFunctional'>
These conv + add pairs should have been fused. Check the model.view_transformation.input_proj module:
```python
class BevFormer(BaseModule):
    def process_input(self, feats):
        ...
        for cam_idx in range(num_cameras):
            # Due to the presence of dynamic code blocks, the dynamic_block
            # annotation is necessary for proper fusion.
            with Tracer.dynamic_block(self, "bevformer_process_input"):
                src = cur_fpn_lvl_feat[cam_idx]
                bs, _, h, w = src.shape
                spatial_shape = (h, w)
                spatial_shapes.append(spatial_shape)
                src = self.input_proj[str(cam_idx)](src)
                src = src + self.cams_embeds[cam_idx][None, :, None, None]
                src = src + self.level_embeds[feat_idx][None, :, None, None]
                # B, C, H, W --> B, C, H*W --> B, H*W, C
                src = src.flatten(2).transpose(1, 2)
                src_flatten.append(src)
```
Shared Modules
Next, check the shared modules. A module with called times > 1 is shared across multiple call sites and should be split into separate instances.
Each module called times:
name called times
--------------------------------------------------------------------------------------- --------------
...
model.map_head.sparse_head.decoder.gen_sineembed_for_position.div.reciprocal 8
model.map_head.sparse_head.decoder.gen_sineembed_for_position.div.mul 8
model.map_head.sparse_head.decoder.gen_sineembed_for_position.sin_model.sin 8
model.map_head.sparse_head.decoder.gen_sineembed_for_position.cos_model.cos 8
model.map_head.sparse_head.decoder.gen_sineembed_for_position.stack 8
model.map_head.sparse_head.decoder.gen_sineembed_for_position.cat 4
model.map_head.sparse_head.decoder.gen_sineembed_for_position.mul 8
model.map_head.sparse_head.decoder.gen_sineembed_for_position.dim_t_quant 4
...
After calibration or QAT training, if the accuracy is poor, you can use the debug tool to inspect the statistics of shared operators in compare_per_layer. The base model is the floating-point model; the analysis model is the calibration model. Take the two calls of model.map_head.sparse_head.decoder.gen_sineembed_for_position.div.mul as an example: the maximum representable value is 128 * 0.0446799 ≈ 5.719. At the first call site (row 1227), the output range is clearly within [-5.719, 5.719], and the error is small. At the second call site (row 1520), however, the output range exceeds [-5.719, 5.719], so values are truncated and the error is much larger. The difference in value ranges between the two call sites also makes the shared scale statistics inaccurate.
+------+-------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------+--------------------------------+---------------+------------+------------------+-------------------+------------------+-------------------+
| | mod_name | base_op_type | analy_op_type | shape | quant_dtype | qscale | base_model_min | analy_model_min | base_model_max | analy_model_max |
|------+-------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------+--------------------------------+---------------+------------+------------------+-------------------+------------------+-------------------+
...
| 1227 | model.map_head.sparse_head.decoder.gen_sineembed_for_position.div | horizon_plugin_pytorch.nn.div.Div | horizon_plugin_pytorch.nn.qat.functional_modules.FloatFunctional.mul | torch.Size([1, 1600, 128]) | qint8 | 0.0446799 | 0.0002146 | 0.0000000 | 4.5935526 | 4.5567998 |
...
| 1520 | model.map_head.sparse_head.decoder.gen_sineembed_for_position.div | horizon_plugin_pytorch.nn.div.Div | horizon_plugin_pytorch.nn.qat.functional_modules.FloatFunctional.mul | torch.Size([1, 1600, 128]) | qint8 | 0.0446799 | 0.0000000 | 0.0000000 | 6.2831225 | 5.7190272 |
...
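The truncation at the second call site can be reproduced with a minimal numpy sketch. This is an illustration only: it assumes a standard symmetric int8 range of [-128, 127], while the table above suggests the tool clamps one quantum higher (128 * scale); this does not change the conclusion.

```python
import numpy as np

def fake_quant_int8(x, scale, qmin=-128, qmax=127):
    # Linear fake quantization: scale, round, clamp to the integer range.
    return np.clip(np.round(np.asarray(x) / scale), qmin, qmax) * scale

scale = 0.0446799  # the shared qscale from the table above

# First call site: max value ~4.59 is inside the representable range,
# so the error stays below one quantum.
err_call1 = abs(fake_quant_int8(4.5935526, scale) - 4.5935526)

# Second call site: max value 2*pi exceeds the representable range and
# is truncated to roughly 127 * scale ~= 5.67.
err_call2 = abs(fake_quant_int8(6.2831225, scale) - 6.2831225)

print(err_call1, err_call2)  # err_call2 is far larger due to truncation
```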
To fix this, split the shared module. Check the model.map_head.sparse_head.decoder.gen_sineembed_for_position module:
```python
class AnchorDeformableTransformerDecoder(nn.Module):
    def __init__(self, decoder_layer, num_layers, return_intermediate=False):
        ...
        # Construct a separate gen_sineembed_for_position per layer and use
        # them independently.
        for i in range(len(self.layers)):
            self.add_module(
                "gen_sineembed_for_position%d" % (i), PositionEmbedding()
            )

    def forward(...):
        ...
        for lid, layer in enumerate(self.layers):
            ref_shape = reference_points.shape
            assert ref_shape[-1] == 2
            reference_points_reshape = reference_points.view(ref_shape[0], -1, 2)
            query_sine_embed = getattr(
                self, "gen_sineembed_for_position%d" % (lid)
            )(reference_points_reshape)
            ...
```
QConfig Configuration Errors
input dtype statistics:
+----------------------------------------------------------------------------+-----------------+---------+----------+----------+
| module type | torch.float32 | qint8 | qint16 | qint32 |
|----------------------------------------------------------------------------+-----------------+---------+----------+----------|
| <class 'horizon_plugin_pytorch.nn.qat.stubs.QuantStub'> | 290 | 15 | 0 | 0 |
| <class 'horizon_plugin_pytorch.nn.qat.conv2d.ConvReLU2d'> | 0 | 6 | 0 | 0 |
| <class 'horizon_plugin_pytorch.nn.qat.conv2d.Conv2d'> | 0 | 228 | 0 | 0 |
| <class 'horizon_plugin_pytorch.nn.qat.gelu.GELU'> | 0 | 63 | 0 | 0 |
| <class 'horizon_plugin_pytorch.nn.qat.functional_modules.FloatFunctional'> | 3 | 425 | 725 | 140 |
| <class 'horizon_plugin_pytorch.nn.qat.batchnorm.BatchNorm2d'> | 0 | 9 | 0 | 0 |
| <class 'horizon_plugin_pytorch.nn.qat.linear.Linear'> | 5 | 117 | 9 | 72 |
| <class 'horizon_plugin_pytorch.nn.qat.segment_lut.SegmentLUT'> | 0 | 64 | 125 | 0 |
| <class 'torch.nn.modules.dropout.Dropout'> | 0 | 53 | 0 | 0 |
| <class 'horizon_plugin_pytorch.nn.qat.linear.LinearReLU'> | 1 | 17 | 0 | 28 |
| <class 'horizon_plugin_pytorch.nn.qat.conv_transpose2d.ConvTranspose2d'> | 0 | 1 | 0 | 0 |
| <class 'horizon_plugin_pytorch.nn.qat.relu.ReLU'> | 0 | 8 | 0 | 56 |
| <class 'horizon_plugin_pytorch.nn.qat.linear.LinearAdd'> | 0 | 4 | 0 | 4 |
| <class 'horizon_plugin_pytorch.nn.qat.stubs.DeQuantStub'> | 0 | 8 | 0 | 4 |
| total | 299 | 1018 | 859 | 304 |
+----------------------------------------------------------------------------+-----------------+---------+----------+----------+
output dtype statistics:
+----------------------------------------------------------------------------+-----------------+---------+----------+----------+
| module type | torch.float32 | qint8 | qint16 | qint32 |
|----------------------------------------------------------------------------+-----------------+---------+----------+----------|
| <class 'horizon_plugin_pytorch.nn.qat.stubs.QuantStub'> | 0 | 123 | 182 | 0 |
| <class 'horizon_plugin_pytorch.nn.qat.conv2d.ConvReLU2d'> | 0 | 6 | 0 | 0 |
| <class 'horizon_plugin_pytorch.nn.qat.conv2d.Conv2d'> | 0 | 228 | 0 | 0 |
| <class 'horizon_plugin_pytorch.nn.qat.gelu.GELU'> | 0 | 63 | 0 | 0 |
| <class 'horizon_plugin_pytorch.nn.qat.functional_modules.FloatFunctional'> | 0 | 341 | 716 | 64 |
| <class 'horizon_plugin_pytorch.nn.qat.batchnorm.BatchNorm2d'> | 0 | 9 | 0 | 0 |
| <class 'horizon_plugin_pytorch.nn.qat.linear.Linear'> | 0 | 85 | 18 | 100 |
| <class 'horizon_plugin_pytorch.nn.qat.segment_lut.SegmentLUT'> | 0 | 55 | 134 | 0 |
| <class 'torch.nn.modules.dropout.Dropout'> | 0 | 53 | 0 | 0 |
| <class 'horizon_plugin_pytorch.nn.qat.linear.LinearReLU'> | 0 | 18 | 0 | 28 |
| <class 'horizon_plugin_pytorch.nn.qat.conv_transpose2d.ConvTranspose2d'> | 0 | 1 | 0 | 0 |
| <class 'horizon_plugin_pytorch.nn.qat.relu.ReLU'> | 0 | 8 | 0 | 56 |
| <class 'horizon_plugin_pytorch.nn.qat.linear.LinearAdd'> | 0 | 4 | 0 | 0 |
| <class 'horizon_plugin_pytorch.nn.qat.stubs.DeQuantStub'> | 12 | 0 | 0 | 0 |
| total | 12 | 994 | 1050 | 248 |
+----------------------------------------------------------------------------+-----------------+---------+----------+----------+
Each layer out qconfig:
+---------------------------------------------------------------------------------------------+----------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+-----------------+-------------------+--------------+
| Module Name | Module Type | Input dtype | out dtype | ch_axis | observer |
|---------------------------------------------------------------------------------------------+----------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+-----------------+-------------------+--------------|
...
| model.obstacle_head.reg_branches.0.0 | <class 'horizon_plugin_pytorch.nn.qat.linear.Linear'> | ['qint8'] | ['qint32'] | -1 | MixObserver |
...
| model.obstacle_head.reg_out_dequant0 | <class 'horizon_plugin_pytorch.nn.qat.stubs.DeQuantStub'> | ['qint8'] | [torch.float32] | qconfig = None | |
There are the following issues:

- There is an operator with a qint32 output. High-precision output is shown as torch.float32 in this table, not qint32, so this is likely a qconfig configuration error. The qconfig for model.obstacle_head.reg_branches.0.0 needs to be corrected.
- The DeQuantStub has a qint8 input, which indicates that some parts of the model are not using high-precision output. The template only supports automatic high-precision output for GEMM-like operators before dequant. Check whether the operators before dequant are GEMM-like.
- Apart from quant and dequant, a small number of operators have fp32 inputs, indicating that some parts may be missing quant nodes. If you compile the model, you will find that it is not fully quantized: some operators fall back to the CPU. Check the input and output types of each layer in detail to identify which operators need a QuantStub inserted before them.
Mixed Precision Tuning
The overall process starts with all-int16 precision tuning. This stage confirms the model's maximum achievable quantized precision, rules out tool usage issues, and identifies modules that are not quantization-friendly. Once all-int16 precision meets the requirement, proceed with all-int8 precision tuning. If that precision falls short, move on to int8 / int16 mixed-precision tuning, gradually increasing the proportion of int16 operators on top of the all-int8 model. In this phase, strike a balance between precision and performance: find the quantization configuration with the best performance that still meets the precision requirement.
All-INT16 Precision Tuning
| Branch | float | target (loss < 1%) | all-int16 calibration |
|---|---|---|---|
| dynamic | 73.6 | 72.864 | 0 |
| static | 55.5 | 54.945 | 0 |
On the first calibration run, the precision collapses. Normally, all-int16 quantization should not cause a precision collapse, so debug the outputs responsible for the precision drop. You can choose either the static or the dynamic branch's output and apply the following modifications. This example includes some post-processing (sigmoid), and debugging is performed specifically on the last-layer output of the static branch (map).
```python
class BevNet(pl.LightningModule):
    def forward(self, batch):
        ...
        # return image, calibration, gt_dict, mask_dict, pred_dict
        # Only debug the outputs responsible for the precision drop. Debugging
        # all outputs also works, but it lacks focus and is slower.
        return pred_dict['map']['preds']['layer_3']


class MapQRHead(BaseModule):
    def forward_branch(self, hs, init_reference, inter_references):
        outputs = []
        for lvl in range(hs.shape[0]):
            ...
            cls_out = getattr(self, "cls_out_dequant%d" % (lvl))(
                self.cls_branches[lvl](hs[lvl])
            )
            pts_out = getattr(self, "pts_out_dequant%d" % (lvl))(
                self.pts_branches[lvl](hs[lvl])
            ).view(bs, -1, 2)
            # sigmoid belongs to the post-processing logic. Although it sits
            # after dequant, it should still be included in the debugged output.
            pts_out = pts_out.sigmoid()
            pts_out = pts_out.view(bs, self.num_polyline, self.num_pts_per_polyline, -1)
            # bs, num_polyline, num_pts, n_attr
            y = torch.cat(
                [cls_out, pts_out, attr_out[0], attr_out[1], attr_out[2],
                 attr_out[3], attr_out[4]],
                dim=-1,
            )
            outputs.append(y)
        return outputs
```
The debug results are as follows:
op_name sensitive_type op_type L1
--------------------------------------------------------------------------------------- ---------------- -------------------------------------------------------------------------- ----------
model.view_transformation.transformer.layers.0.quant activation <class 'horizon_plugin_pytorch.nn.qat.stubs.QuantStub'> 0.580483
model.obstacle_head.decoder.layers.0.cross_attn.quant_normalizer activation <class 'horizon_plugin_pytorch.nn.qat.stubs.QuantStub'> 0.130977
...
From the perspective of quantization sensitivity, the primary contributors are several QuantStub operators. In the compare_per_layer results, analyze the type of quantization error associated with these sensitive operators.
+------+-------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------+--------------------------------+---------------+------------+------------+--------------+------------+------------+-------------+-------------+-------------------------------------------------+------------------+-------------------+------------------+-------------------+-------------------+--------------------+------------------+-------------------+-----------------+-------------------+
| | mod_name | base_op_type | analy_op_type | shape | quant_dtype | qscale | Cosine | MSE | L1 | KL | SQNR | Atol | Rtol | base_model_min | analy_model_min | base_model_max | analy_model_max | base_model_mean | analy_model_mean | base_model_var | analy_model_var | max_atol_diff | max_qscale_diff |
|------+-------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------+--------------------------------+---------------+------------+------------+--------------+------------+------------+-------------+-------------+-------------------------------------------------+------------------+-------------------+------------------+-------------------+-------------------+--------------------+------------------+-------------------+-----------------+-------------------|
...
| 791 | model.view_transformation.transformer.layers.0.quant | horizon_plugin_pytorch.quantization.stubs.QuantStub | horizon_plugin_pytorch.nn.qat.stubs.QuantStub | torch.Size([1, 1875, 24, 2]) | qint8 | 0.7764707 | 0.9999968 | 0.1134158 | 0.3205845 | 0.0000081 | 46.6785774 | 0.3882294 | 1.0000000 | -99.0000000 | -98.6117706 | 0.9999269 | 0.7764707 | -53.0977783 | -52.8312225 | 2459.8188477 | 2446.8227539 | 0.3882294 | 0.4999923 |
...
| 883 | model.obstacle_head.decoder.layers.0.cross_attn.quant_normalizer | horizon_plugin_pytorch.quantization.stubs.QuantStub | horizon_plugin_pytorch.nn.qat.stubs.QuantStub | torch.Size([1, 7500, 8, 32]) | qint8 | 0.0522360 | 0.3601017 | 0.9630868 | 0.6762735 | 0.0000729 | -0.6995326 | 12.4819937 | 601413.5625000 | -9.2183294 | -6.5295014 | 10.2716885 | 6.6339736 | -0.0177280 | -0.0255637 | 0.8194664 | 0.6810485 | 12.4819937 | 238.9537970 |
...
For model.view_transformation.transformer.layers.0.quant:

- The quantization dtype is still qint8, which indicates that the int16 configuration has not taken effect. Check whether an int8 qconfig has been manually set somewhere outside the setter.
- The scale is 0.7764707, so the representable floating-point range is 0.7764707 * (-128) ≈ -99.39 to 0.7764707 * 127 ≈ 98.61. Given the physical meaning here, the input range of this quant should be -100 to 1, which causes some truncation error. In addition, because the value range is large, int8 also introduces significant rounding error. Therefore, switch to int16 quantization and set a fixed scale of 100 / 32768 according to the input range.

For model.obstacle_head.decoder.layers.0.cross_attn.quant_normalizer, the issue is similar.
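The effect of the fixed int16 scale can be checked numerically. The sketch below is a minimal numpy illustration using standard symmetric ranges; the plugin's exact clamp bounds may differ by one quantum.

```python
import numpy as np

def fake_quant(x, scale, qmin, qmax):
    # Linear fake quantization: scale, round, clamp.
    return np.clip(np.round(np.asarray(x, dtype=np.float64) / scale), qmin, qmax) * scale

x = np.array([-100.0, -99.0, -50.0, 0.0, 1.0])  # samples from the stated [-100, 1] range

# int8 with the observed scale 0.7764707: -100 is truncated to ~ -99.39,
# and every value is rounded onto a coarse ~0.78 grid.
y_int8 = fake_quant(x, 0.7764707, -128, 127)

# int16 with the fixed scale 100 / 32768: covers [-100, ~100) on a ~0.003
# grid, so both truncation and rounding error almost vanish.
y_int16 = fake_quant(x, 100 / 32768, -32768, 32767)

print(np.abs(y_int8 - x).max())   # worst-case error ~0.6 (truncation at -100)
print(np.abs(y_int16 - x).max())  # worst-case error below half an int16 quantum
```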
After the modifications, re-run calibration and debugging. A noticeable improvement can be observed, and the sensitivity of the previously top-ranked operators has dropped significantly. Next, proceed with QAT training.
| Branch | float | target (loss < 1%) | all-int16 calibration |
|---|---|---|---|
| dynamic | 73.6 | 72.864 | 73.4 |
| static | 55.5 | 54.945 | 54 |
During the initial QAT training, the precision collapsed, and the loss curve did not converge.
| Branch | float | target (loss < 1%) | all-int16 calibration | all-int16 qat | finetune float |
|---|---|---|---|---|---|
| dynamic | 73.6 | 72.864 | 73.4 | 0 | 0 |
| static | 55.5 | 54.945 | 54 | 0 | 0 |

Attempt the following:

- Since the calibration precision is very good, assume the model no longer has quantization tool usage issues. First adjust hyperparameters such as learning rate and weight decay; this did not resolve the issue.
- Run accuracy debugging. model.map_head.sparse_head.decoder.gen_sineembed_for_position0.dim_t_quant remains highly sensitive, with significant quantization error even under int16 quantization.
op_name sensitive_type op_type L1 quant_dtype
----------------------------------------------------------------------------- ---------------- -------------------------------------------------------------------------- ---------- -------------
model.map_head.sparse_head.decoder.gen_sineembed_for_position0.dim_t_quant activation <class 'horizon_plugin_pytorch.nn.qat.stubs.QuantStub'> 0.82213 qint16
model.map_head.sparse_head.decoder.gen_sineembed_for_position1.dim_t_quant activation <class 'horizon_plugin_pytorch.nn.qat.stubs.QuantStub'> 0.184159 qint16
model.map_head.sparse_head.pts_branches.0.6 activation <class 'horizon_plugin_pytorch.nn.qat.linear.Linear'> 0.131423 qint16
model.map_head.sparse_head.decoder.gen_sineembed_for_position2.dim_t_quant activation <class 'horizon_plugin_pytorch.nn.qat.stubs.QuantStub'> 0.111852 qint16
model.map_head.sparse_head.pts_branches.1.6 activation <class 'horizon_plugin_pytorch.nn.qat.linear.Linear'> 0.0930651 qint16
model.map_head.sparse_head.sigmoid activation <class 'horizon_plugin_pytorch.nn.qat.segment_lut.SegmentLUT'> 0.0887103 qint16
model.map_head.sparse_head.pts_branches.2.6 activation <class 'horizon_plugin_pytorch.nn.qat.linear.Linear'> 0.0728263 qint16
model.map_head.sparse_head.reference_points.2 activation <class 'horizon_plugin_pytorch.nn.qat.linear.Linear'> 0.0689369 qint16
Reviewing the compare_per_layer results, the representable quantization range is 32767 * 0.2642754 ≈ 8659.51. Although the values are large, there is no noticeable truncation error.
+------+-------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------+----------------------------------+---------------+------------+------------+-------------+-----------+------------+-------------+-------------+-------------------------------------------------+------------------+-------------------+------------------+-------------------+-------------------+--------------------+------------------+-------------------+-----------------+-------------------+
| | mod_name | base_op_type | analy_op_type | shape | quant_dtype | qscale | Cosine | MSE | L1 | KL | SQNR | Atol | Rtol | base_model_min | analy_model_min | base_model_max | analy_model_max | base_model_mean | analy_model_mean | base_model_var | analy_model_var | max_atol_diff | max_qscale_diff |
|------+-------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------+----------------------------------+---------------+------------+------------+-------------+-----------+------------+-------------+-------------+-------------------------------------------------+------------------+-------------------+------------------+-------------------+-------------------+--------------------+------------------+-------------------+-----------------+-------------------|
...
| 1296 | model.map_head.sparse_head.decoder.gen_sineembed_for_position0.dim_t_quant | torch.ao.quantization.stubs.QuantStub | horizon_plugin_pytorch.nn.qat.stubs.QuantStub | torch.Size([128]) | qint16 | 0.2642754 | 0.9999999 | 0.0058058 | 0.0670551 | 0.0000000 | 89.0683746 | 0.1328125 | 0.0845878 | 1.0000000 | 1.0571015 | 8659.6435547 | 8659.5107422 | 1009.3834229 | 1009.3997803 | 3694867.5000000 | 3694819.5000000 | 0.1328125 | 0.5025535 |
...
Print dim_t: its distribution is highly non-uniform, which is unfriendly to linear quantization. With a single scale, the small values suffer significant rounding errors.
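As a quick check, reproducing the same formula in numpy (an illustration; in the model, dim_t lives on the device) shows the huge dynamic range:

```python
import numpy as np

# Same formula as in PositionEmbedding.forward.
dim_t = 10000 ** (2 * (np.arange(128) // 2) / 128)

print(dim_t.min(), dim_t.max())  # spans 1.0 up to ~8659.6
# Half of the 128 values lie below 100, yet a single linear scale must also
# cover values near 8660, so the small values land on a very coarse grid.
print((dim_t < 100).sum())
```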

dim_t is the denominator of a division. When dim_t is small, the rounding errors are amplified by the division. Since dim_t is fixed, and we know the rounding errors on the smaller values have the larger impact, we can simply split it into two groups for quantization, each with its own scale, and concatenate the results after the division. This makes the scale of the first (small-value) group much smaller, which reduces the rounding errors.
```python
class PositionEmbedding(torch.nn.Module):
    def forward(self, pos_tensor):
        dim_t = torch.arange(128, dtype=torch.float32, device=pos_tensor.device)
        dim_t = 10000 ** (2 * (dim_t // 2) / 128)
        # dim_t = self.dim_t_quant(dim_t)
        # pos_x = x_embed[:, :, None] / dim_t
        dim_t_1 = dim_t[:32]  # the split point 32 can be tuned up or down
        dim_t_2 = dim_t[32:]
        pos_x_1 = x_embed_1[:, :, None] / dim_t_1
        pos_x_2 = x_embed_2[:, :, None] / dim_t_2
        # concatenate along the embedding dim to restore the original shape
        pos_x = torch.cat([pos_x_1, pos_x_2], dim=-1)
```
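Why the split helps can be verified numerically. The sketch below is a numpy illustration assuming int16 symmetric fake quantization with each observer's scale taken as max / 32767 (an assumed calibration rule); it compares the worst-case error of 1 / dim_t with one shared scale versus per-group scales:

```python
import numpy as np

dim_t = 10000 ** (2 * (np.arange(128) // 2) / 128)

def fake_quant_int16(x, scale):
    # symmetric int16 fake quantization
    return np.clip(np.round(x / scale), -32768, 32767) * scale

# One shared scale for the whole tensor: the smallest entries (near 1.0)
# sit on a coarse ~0.26 grid, so 1 / dim_t is badly distorted there.
s_all = dim_t.max() / 32767
err_all = np.abs(1 / fake_quant_int16(dim_t, s_all) - 1 / dim_t).max()

# Separate scales per group: the first group's scale shrinks ~1000x.
s_lo = dim_t[:32].max() / 32767
s_hi = dim_t[32:].max() / 32767
err_split = max(
    np.abs(1 / fake_quant_int16(dim_t[:32], s_lo) - 1 / dim_t[:32]).max(),
    np.abs(1 / fake_quant_int16(dim_t[32:], s_hi) - 1 / dim_t[32:]).max(),
)

print(err_all, err_split)  # err_split is around two orders of magnitude smaller
```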
At this point, after the adjustments, the all-int16 calibration precision shows some improvement, but the QAT precision still collapses.
| Branch | float | target (loss < 1%) | all-int16 calibration | all-int16 qat | finetune float |
|---|---|---|---|---|---|
| dynamic | 73.6 | 72.864 | 73.4 | 0 | 0 |
| static | 55.5 | 54.945 | 54.7 | 0 | 0 |
Start troubleshooting the training pipeline:

- Using the QAT pipeline and training parameters, finetune the floating-point model and compare it with floating-point training. The loss is still large, and the precision still collapses.
- Set the learning rate to 0 and finetune the floating-point model; the precision still collapses. The conclusion is that the issue is unrelated to QAT: the code used for QAT training is not aligned with the code used for floating-point training.
After carefully comparing the modification records and resolving the alignment issues, the precision now meets the requirements.
| Branch | float | target (loss < 1%) | all-int16 calibration | all-int16 qat |
|---|---|---|---|---|
| dynamic | 73.6 | 72.864 | 73.4 | 74.1 |
| static | 55.5 | 54.945 | 54.7 | 55.3 |
All-INT8 Precision Tuning
Because we already identified some sensitive operators during the all-int16 debug, we can set those operators to int16 directly during all-int8 tuning. (Had they not been identified there, they would simply stay int8 and would be caught again in the int8 precision debug.) The remaining operators use int8. With this quantization configuration, both calibration and QAT precision collapse.
| Branch | float | target (loss < 1%) | all-int16 calibration | all-int16 qat | all-int8 calibration | all-int8 qat |
|---|---|---|---|---|---|---|
| dynamic | 73.6 | 72.864 | 73.4 | 74.1 | 0 | 0 |
| static | 55.5 | 54.945 | 54.7 | 55.3 | 0 | 0 |
Additional int16 operators need to be identified through precision debugging.
INT8 / INT16 Mixed Precision Tuning
Based on the all-int8 debug results, set the most sensitive operators to int16.
op_name sensitive_type op_type L1 quant_dtype
--------------------------------------------------------------------------------------- ---------------- -------------------------------------------------------------------------- --------- -------------
model.view_transformation.ref_point_quant activation <class 'horizon_plugin_pytorch.nn.qat.stubs.QuantStub'> 1.79371 qint8
model.map_head.sparse_head.decoder.gen_sineembed_for_position0.dim_t_quant activation <class 'horizon_plugin_pytorch.nn.qat.stubs.QuantStub'> 1.46594 qint8
model.map_head.sparse_head.decoder.reg_branch_output_add0 activation <class 'horizon_plugin_pytorch.nn.qat.functional_modules.FloatFunctional'> 0.352401 qint8
model.map_head.sparse_head.decoder.gen_sineembed_for_position1.dim_t_quant activation <class 'horizon_plugin_pytorch.nn.qat.stubs.QuantStub'> 0.246953 qint8
model.view_transformation.transformer.layers.0.linear2 activation <class 'horizon_plugin_pytorch.nn.qat.linear.Linear'> 0.22353 qint8
model.map_head.sparse_head.decoder.gen_sineembed_for_position2.dim_t_quant activation <class 'horizon_plugin_pytorch.nn.qat.stubs.QuantStub'> 0.214513 qint8
model.map_head.sparse_head.decoder.reg_branch_output_add1 activation <class 'horizon_plugin_pytorch.nn.qat.functional_modules.FloatFunctional'> 0.185211 qint8
model.map_head.sparse_head.sigmoid activation <class 'horizon_plugin_pytorch.nn.qat.segment_lut.SegmentLUT'> 0.1826 qint8
model.map_head.sparse_head.decoder.gen_sineembed_for_position3.dim_t_quant activation <class 'horizon_plugin_pytorch.nn.qat.stubs.QuantStub'> 0.166937 qint8
Finally, setting the top 0.5% most sensitive operators to int16, the precision meets the requirements. Next, parameters such as learning rate and weight decay can be finetuned further to reach the QAT precision target with fewer int16 operators.
| Branch | float | target (loss < 1%) | all-int8 calibration | all-int8 qat | int8 + ref_point_quant int16 calibration | int8 + ref_point_quant int16 qat | int8 + sensitivity top 0.5% int16 calibration | int8 + sensitivity top 0.5% int16 qat |
|---|---|---|---|---|---|---|---|---|
| dynamic | 73.6 | 72.864 | 0 | 0 | 70.9 | 71.3 | 71 | 73.1 |
| static | 55.5 | 54.945 | 0 | 0 | 7.5 | 27.8 | 53.7 | 55.1 |
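The "top 0.5% by sensitivity" selection itself is mechanical. A minimal sketch (`pick_int16_ops` is a hypothetical helper, and the list below holds only the first few entries of the real sensitivity table):

```python
# Hypothetical parsed sensitivity results: (op_name, L1 error) pairs taken
# from the debug tool's output table, one entry per quantized operator.
sensitivities = [
    ("model.view_transformation.ref_point_quant", 1.79371),
    ("model.map_head.sparse_head.decoder.gen_sineembed_for_position0.dim_t_quant", 1.46594),
    ("model.map_head.sparse_head.decoder.reg_branch_output_add0", 0.352401),
    # ... one entry per operator in the real model
]

def pick_int16_ops(sensitivities, ratio=0.005):
    """Return the names of the top `ratio` fraction of ops by sensitivity."""
    k = max(1, round(len(sensitivities) * ratio))
    ranked = sorted(sensitivities, key=lambda t: t[1], reverse=True)
    return [name for name, _ in ranked[:k]]

# These operators get an int16 qconfig; the rest stay int8.
int16_ops = pick_int16_ops(sensitivities)
print(int16_ops)
```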
Review of the Process
- Check sensitivity first (on the outputs related to the accuracy drop), then move to the compare_per_layer results. Once a sensitive operator is identified, determine from the statistics whether the error comes from rounding or truncation, and adjust the quantization configuration accordingly.
- Int8 / int16 tuning is handled by the sensitivity setter: you only need to set the ratio, so it is not particularly difficult. The focus of precision tuning should be the all-int16 stage, which is where the hard problems live: tool usage issues, quantization-unfriendly modules, and other complications.
- All-int16 calibration should reach normal precision; a precision collapse indicates a usage issue. The process is iterative debugging: analyze the top sensitive operators as described above and adjust the quantization configuration until the target precision is reached. Some modifications may not immediately improve accuracy, but you should observe reduced sensitivity for the modified operators.