Quantization-Aware Training Guide

Quantization-aware training (QAT) inserts pseudo-quantization nodes into a model so that, after training, the model can be converted into a fixed-point model with minimal loss of accuracy. In principle, QAT is no different from ordinary model training: one can build a pseudo-quantized model from scratch and train it directly. In practice, however, the deployment hardware imposes constraints that are difficult to understand and account for when building such a model by hand. The QAT tool removes this burden by automatically inserting pseudo-quantization operators into the provided floating-point model according to the constraints of the deployment platform.

Because of these added constraints, quantization-aware training is generally harder than training a pure floating-point model. The goal of the QAT tool is to reduce both the difficulty of quantization-aware training and the engineering effort of deploying quantized models.

Process and Example

Although our QAT tool does not require a pre-trained floating-point model as a starting point, experience shows that starting quantization-aware training from a well-trained, high-precision floating-point model usually makes training significantly easier.

```python
# convert the model to QAT state
qat_model = prepare(
    float_model,
    example_input,
    qconfig_setter=horizon.quantization.qconfig_template.default_qat_qconfig_setter,
).to(device)

# load the quantization parameters from the Calibration model
qat_model.load_state_dict(calib_model.state_dict())

# perform quantization-aware training
# as a finetune process, QAT generally requires a small learning rate
optimizer = torch.optim.SGD(
    qat_model.parameters(), lr=0.0001, weight_decay=2e-4
)

for nepoch in range(epoch_num):
    # note how the training state of the QAT model is controlled here
    qat_model.train()
    set_fake_quantize(qat_model, FakeQuantState.QAT)

    train_one_epoch(
        qat_model,
        nn.CrossEntropyLoss(),
        optimizer,
        None,
        train_data_loader,
        device,
    )

    # note how the eval state of the QAT model is controlled here
    qat_model.eval()
    set_fake_quantize(qat_model, FakeQuantState.VALIDATION)

    # test QAT model accuracy
    top1, top5 = evaluate(
        qat_model,
        eval_data_loader,
        device,
    )
    print(
        "QAT model: evaluation Acc@1 {:.3f} Acc@5 {:.3f}".format(
            top1.avg, top5.avg
        )
    )

    # test quantized model accuracy
    qat_hbir_model = horizon_plugin_pytorch.quantization.hbdk4.export(
        qat_model, example_input
    )
    quantized_hbir_model = hbdk4.compiler.convert(qat_hbir_model)
    top1, top5 = evaluate(
        quantized_hbir_model,
        eval_data_loader,
    )
    print(
        "Quantized model: evaluation Acc@1 {:.3f} Acc@5 {:.3f}".format(
            top1.avg, top5.avg
        )
    )
```
Attention

Due to the underlying limitations of the deployment platform, the accuracy of the QAT model does not fully represent the final on-board accuracy. Be sure to also monitor the accuracy of the converted quantized model; otherwise the deployed model may suffer an unexpected accuracy drop.

As the sample code above shows, quantization-aware training adds two steps compared to traditional pure floating-point model training:

  1. prepare. The goal of this step is to transform the floating-point network and insert pseudo-quantized nodes.

  2. Load the Calibration model parameters. Loading the pseudo-quantization parameters obtained from Calibration gives the QAT model a better initialization.

    When modifying a state_dict, pay attention not only to the keys and values but also to the _metadata attribute. Below is an example of copying a state_dict:

    ```python
    import copy
    from collections import OrderedDict

    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        new_state_dict[k] = v
    if hasattr(state_dict, "_metadata"):
        new_state_dict._metadata = copy.deepcopy(state_dict._metadata)
    ```
Attention

The compatibility of operators depends on the _version variable of torch.nn.Module, which is stored in state_dict._metadata. Please ensure that the _metadata is preserved during the process of saving or loading state_dict, as its absence may lead to compatibility issues.
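The effect of this copy pattern can be illustrated with a plain OrderedDict standing in for a real state_dict (an illustrative stand-in, not an actual torch module's state_dict):

```python
import copy
from collections import OrderedDict

# stand-in for a module's state_dict: an OrderedDict carrying a _metadata
# attribute, just as torch.nn.Module.state_dict() does
state_dict = OrderedDict(weight=[1.0, 2.0], bias=[0.5])
state_dict._metadata = {"": {"version": 1}}

# a naive copy via dict() silently drops _metadata,
# because a plain dict cannot carry extra attributes
naive_copy = dict(state_dict)
print(hasattr(naive_copy, "_metadata"))

# the pattern from the guide preserves both entries and _metadata
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    new_state_dict[k] = v
if hasattr(state_dict, "_metadata"):
    new_state_dict._metadata = copy.deepcopy(state_dict._metadata)

print(new_state_dict._metadata)
```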

At this point, the pseudo-quantized model has been constructed and its parameters initialized. Regular training iterations and parameter updates can then be performed, while monitoring the accuracy of the quantized model.

To meet the requirements of segmented deployment, or to align with floating-point training strategies, it may be necessary to freeze certain parts of the model during training. You can refer to the following code to perform the freezing:

```python
from horizon_plugin_pytorch.quantization import freeze_qat_module

# Model weights / quantization parameters will be fixed, and all operators
# will be set to eval mode.
# Ensure that freeze_qat_module is called after interfaces that may change
# the model state, such as train(), eval(), or set_fake_quantize().
freeze_qat_module(model)
```

Pseudo-quantization Operators

The main difference between quantization-aware training and traditional floating-point model training is the insertion of pseudo-quantization operators. Since different QAT algorithms are also expressed through these operators, we give a brief introduction to pseudo-quantization operators here.

Note

Since the BPU only supports symmetric quantization, symmetric quantization is taken as the example here.

Pseudo-quantization Process

Taking int8 quantization-aware training as an example, the pseudo-quantization operator is generally computed as: `fake_quant_x = clip(round(x / scale), -128, 127) * scale`.
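As a quick numeric illustration of this formula (plain Python, not the tool's implementation), with `scale = 1/128` the representable range is roughly [-1, 0.992]:

```python
def fake_quant(x, scale):
    # quantize to the int8 grid, then map back to float
    q = min(127, max(-128, round(x / scale)))
    return q * scale

scale = 1.0 / 128.0
print(fake_quant(0.05, scale))   # snapped to 6/128 = 0.046875
print(fake_quant(1.5, scale))    # clipped to 127/128 = 0.9921875
print(fake_quant(-2.0, scale))   # clipped to -128/128 = -1.0
```

Note how values inside the range are snapped to the nearest quantization step, while values outside the range saturate at the int8 boundaries.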

Similar to Conv2d, which optimizes its weight and bias parameters through training, the pseudo-quantization operator needs training to optimize its scale parameter. However, round is a step function whose gradient is 0 almost everywhere, so the pseudo-quantization operator cannot be trained directly by backpropagating gradients. Two kinds of solutions are commonly used: statistical-based approaches and learning-based approaches.

Statistical-based Approach

The goal of quantization is to uniformly map the floating-point numbers in a Tensor onto the range [-128, 127] representable by int8 via the scale parameter. Since the mapping is uniform, the calculation of scale follows directly:

```python
def compute_scale(x: Tensor):
    xmin, xmax = x.min(), x.max()
    return max(xmin.abs(), xmax.abs()) / 128.0
```

Because the data in a Tensor is unevenly distributed and may contain outliers, different methods for computing xmin and xmax have been developed; you can refer to the introductions of interfaces such as MinMaxObserver in the Observer Parameters section.
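For illustration, one common statistical variant smooths the min/max statistics over batches with an exponential moving average, which damps batch-to-batch fluctuations and one-off outliers. The sketch below is a hypothetical plain-Python `MovingAverageMinMax`, not the plugin's MinMaxObserver implementation:

```python
class MovingAverageMinMax:
    """Illustrative observer: tracks smoothed min/max across batches."""

    def __init__(self, momentum=0.9):
        self.momentum = momentum
        self.xmin = None
        self.xmax = None

    def update(self, batch):
        bmin, bmax = min(batch), max(batch)
        if self.xmin is None:
            # first batch initializes the statistics
            self.xmin, self.xmax = bmin, bmax
        else:
            # exponential moving average damps outliers
            m = self.momentum
            self.xmin = m * self.xmin + (1 - m) * bmin
            self.xmax = m * self.xmax + (1 - m) * bmax

    def scale(self):
        # symmetric int8 quantization, as in compute_scale above
        return max(abs(self.xmin), abs(self.xmax)) / 128.0


obs = MovingAverageMinMax(momentum=0.9)
obs.update([-1.0, 0.5])
obs.update([-0.5, 2.0])  # the outlier 2.0 only shifts xmax slightly
print(obs.xmin, obs.xmax, obs.scale())
```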

Please refer to QConfig in Detail for the usage in the tool.

Learning-based Approach

Although the gradient of round is 0, researchers have found experimentally that, in this scenario, directly setting the gradient to 1 still allows the model to converge to the expected accuracy. This trick is commonly known as the straight-through estimator (STE):

```python
def round_ste(x: Tensor):
    # forward: round(x); backward: the gradient passes through unchanged
    return (x.round() - x).detach() + x
```
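To see the estimator in action (a standalone sketch assuming a standard PyTorch install, not the tool's FakeQuantize), note that the forward value is rounded while the gradient is identically 1, so the scale parameter remains trainable:

```python
import torch


def round_ste(x: torch.Tensor) -> torch.Tensor:
    # forward: round(x); backward: identity gradient (straight-through)
    return (x.round() - x).detach() + x


x = torch.tensor([0.3, 1.7], requires_grad=True)
y = round_ste(x)
print(y.detach())   # forward output is rounded: [0., 2.]
y.sum().backward()
print(x.grad)       # gradient is 1 everywhere despite the step function
```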

Please refer to Definition of FakeQuantize for the usage in the tool.

If you are interested in learning more, you can refer to the paper Learned Step Size Quantization.