Algorithm Model QAT + Deployment Quick Start
Basic Process
The basic workflow of the quantization-aware training (QAT) tool is as follows:

The following uses the MobileNetV2 model from torchvision as an example to walk through each stage of the process.
To keep this walkthrough fast to execute, we use the cifar-10 dataset instead of ImageNet-1K.
import os
import copy
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch import Tensor
from torch.quantization import DeQuantStub
from torchvision.datasets import CIFAR10
from torchvision.models.mobilenetv2 import MobileNetV2
from torch.utils import data
from typing import Optional, Callable, List, Tuple
from horizon_plugin_pytorch.functional import rgb2centered_yuv
import torch.quantization
from horizon_plugin_pytorch.march import set_march
from horizon_plugin_pytorch.quantization import (
    QuantStub,
    prepare,
    set_fake_quantize,
    FakeQuantState,
)
from horizon_plugin_pytorch.quantization.qconfig_template import (
    default_calibration_qconfig_setter,
    default_qat_qconfig_setter,
)
from hbdk4 import compiler as hb4
import logging
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
class AverageMeter(object):
    """Computes and stores the average and current value."""

    def __init__(self, name: str, fmt=":f"):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
        return fmtstr.format(**self.__dict__)
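The update(val, n) bookkeeping above maintains a batch-size-weighted running mean. A minimal pure-Python sketch (illustrative only, not part of the training tool) shows why weighting by n matters:

```python
# Reproduce AverageMeter's bookkeeping: update(val, n) weights each
# value by its batch size, so avg is the overall per-sample mean.
def running_avg(updates):
    total, count = 0.0, 0
    for val, n in updates:
        total += val * n
        count += n
    return total / count

# Two batches of different sizes give a per-sample mean, not a mean of means.
avg = running_avg([(90.0, 100), (60.0, 50)])
print(avg)  # 80.0, i.e. (90*100 + 60*50) / 150
```

An unweighted mean of the two batch accuracies would be 75.0; the weighted form matches what the full dataset would report.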
def accuracy(output: Tensor, target: Tensor, topk=(1,)) -> List[Tensor]:
    """Computes the accuracy over the k top predictions for the specified
    values of k.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].float().sum()
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
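For intuition, the same top-k check can be sketched without torch. The helper below is a hypothetical plain-Python equivalent that treats each row of scores as one sample:

```python
def topk_accuracy(scores, targets, k):
    """Fraction (in %) of samples whose target class is among the k
    highest-scoring classes, mirroring the Tensor-based accuracy() above."""
    hits = 0
    for sample_scores, target in zip(scores, targets):
        # Indices of the k largest scores, analogous to Tensor.topk.
        topk_idx = sorted(range(len(sample_scores)),
                          key=lambda i: sample_scores[i], reverse=True)[:k]
        hits += target in topk_idx
    return 100.0 * hits / len(targets)

scores = [[0.1, 0.7, 0.2], [0.5, 0.3, 0.2]]
targets = [1, 1]  # sample 0 is a top-1 hit; sample 1 only a top-2 hit
print(topk_accuracy(scores, targets, 1))  # 50.0
print(topk_accuracy(scores, targets, 2))  # 100.0
```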
def evaluate(
    model: nn.Module, data_loader: data.DataLoader, device: torch.device
) -> Tuple[AverageMeter, AverageMeter]:
    top1 = AverageMeter("Acc@1", ":6.2f")
    top5 = AverageMeter("Acc@5", ":6.2f")

    with torch.no_grad():
        for image, target in data_loader:
            image, target = image.to(device), target.to(device)
            output = model(image)
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            top1.update(acc1, image.size(0))
            top5.update(acc5, image.size(0))
            print(".", end="", flush=True)
        print()

    return top1, top5
def train_one_epoch(
    model: nn.Module,
    criterion: Callable,
    optimizer: torch.optim.Optimizer,
    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler],
    data_loader: data.DataLoader,
    device: torch.device,
) -> None:
    top1 = AverageMeter("Acc@1", ":6.3f")
    top5 = AverageMeter("Acc@5", ":6.3f")
    avgloss = AverageMeter("Loss", ":1.5f")

    model.to(device)

    for image, target in data_loader:
        image, target = image.to(device), target.to(device)
        output = model(image)
        loss = criterion(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if scheduler is not None:
            scheduler.step()
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        top1.update(acc1, image.size(0))
        top5.update(acc5, image.size(0))
        avgloss.update(loss, image.size(0))
        print(".", end="", flush=True)
    print()

    print(
        "Full cifar-10 train set: Loss {:.3f} Acc@1"
        " {:.3f} Acc@5 {:.3f}".format(avgloss.avg, top1.avg, top5.avg)
    )
Building Floating-point Model
First, after obtaining the floating-point model, we need to make the necessary modifications so that it can support quantization-related operations. The required modifications are:

- Insert a QuantStub node before each model input.
- Insert a DeQuantStub node after each model output.
Attention
Be sure to insert the QuantStub and DeQuantStub nodes at the boundary between deployment and non-deployment logic. For example, a head that only computes an auxiliary loss does not need to be deployed, so a DeQuantStub should be inserted at the input of that head to ensure the head is not quantized.
Points to note when modifying the model:

- The inserted QuantStub and DeQuantStub nodes must be registered as submodules of the model; otherwise their quantization state will not be handled correctly.
- Multiple inputs may share one QuantStub only if they have the same scale; otherwise define a separate QuantStub for each input.
- If the data source of an input on board is "pyramid", manually set the scale parameter of the corresponding QuantStub to 1/128.
- torch.quantization.QuantStub can also be used, but only horizon_plugin_pytorch.quantization.QuantStub supports manually fixing the scale via its parameter.
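To see what fixing the scale to 1/128 implies, the following sketch applies the generic int8 fake-quantization formula (quantize to an integer grid, clamp to [-128, 127], dequantize). It is a conceptual illustration, not the plugin's implementation:

```python
def fake_quantize_int8(x, scale):
    """Simulate int8 quantization: snap x to an integer grid of step
    `scale`, clamp to the int8 range, then rescale back to float."""
    q = round(x / scale)
    q = max(-128, min(127, q))  # clamp to the int8 range
    return q * scale

scale = 1.0 / 128.0
# With scale = 1/128 the representable range is [-1, 127/128], which
# matches pyramid input data normalized to roughly [-1, 1).
print(fake_quantize_int8(0.5, scale))    # 0.5 (exactly representable)
print(fake_quantize_int8(1.5, scale))    # 0.9921875 (clamped to 127/128)
print(fake_quantize_int8(-2.0, scale))   # -1.0 (clamped to -128/128)
```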
The modified model can seamlessly load the parameters of the unmodified model, so if a trained floating-point model already exists, it can be loaded directly; otherwise, carry out floating-point training as usual.
Attention
On board, the input image data is typically in centered_yuv444 format, so images should be converted to centered_yuv444 during training (note the use of rgb2centered_yuv in the code below).
If training on centered_yuv444 data is not possible, insert an appropriate color-space conversion node at the model input when deploying. (Note that this approach may reduce model accuracy.)
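As a quick sanity check of the centered_yuv444 idea, the sketch below converts a single RGB pixel using textbook BT.601 full-range coefficients and then centers each channel by subtracting 128. The exact coefficients used by rgb2centered_yuv are defined by the plugin, so treat these values as illustrative:

```python
def rgb_pixel_to_centered_yuv(r, g, b):
    """Convert one RGB pixel to YUV (illustrative BT.601 full-range
    coefficients) and center every channel to roughly [-128, 127]."""
    y = 0.299 * r + 0.587 * g + 0.114 * b
    u = -0.169 * r - 0.331 * g + 0.500 * b + 128
    v = 0.500 * r - 0.419 * g - 0.081 * b + 128
    # "centered" means shifting each channel by -128 toward zero mean.
    return y - 128, u - 128, v - 128

# A mid-gray pixel maps to approximately (0, 0, 0) after centering.
print(rgb_pixel_to_centered_yuv(128, 128, 128))
```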
The example in this section has fewer floating-point and QAT training epochs, just to illustrate the process of using the training tool, and the accuracy does not represent the best level of the model.
Note
Here is a simple multi-input and multi-output example. Note that each QuantStub is responsible for quantizing one input Tensor.
class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.quantx = QuantStub()
        self.quanty = QuantStub()
        self.dequant = DeQuantStub()
        ...

    def forward(self, x, y):
        x = self.quantx(x)
        y = self.quanty(y)
        ...
        ret_0 = self.dequant(ret_0)
        ret_1 = self.dequant(ret_1)
        return ret_0, ret_1
######################################################################
# The user can modify the following parameters as required.
# 1. Save path for model ckpt and compiled outputs.
model_path = "model/mobilenetv2"
# 2. Path where the dataset is downloaded and saved.
data_path = "data"
# 3. The batch_size used for training.
train_batch_size = 256
# 4. The batch_size used for prediction.
eval_batch_size = 256
# 5. Number of training epochs.
epoch_num = 10
# 6. The device used to store the model and perform computations.
device = (
    torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
)
######################################################################
######################################################################
# To prepare the dataset, note the use of rgb2centered_yuv in collate_fn.
def prepare_data_loaders(
    data_path: str, train_batch_size: int, eval_batch_size: int
) -> Tuple[data.DataLoader, data.DataLoader]:
    normalize = transforms.Normalize(mean=0.0, std=128.0)

    def collate_fn(batch):
        batched_img = torch.stack(
            [
                torch.from_numpy(np.array(example[0], np.uint8, copy=True))
                for example in batch
            ]
        ).permute(0, 3, 1, 2)
        batched_target = torch.tensor([example[1] for example in batch])

        batched_img = rgb2centered_yuv(batched_img)
        batched_img = normalize(batched_img.float())

        return batched_img, batched_target

    train_dataset = CIFAR10(
        data_path,
        True,
        transforms.Compose(
            [
                transforms.RandomHorizontalFlip(),
                transforms.RandAugment(),
            ]
        ),
        download=True,
    )
    eval_dataset = CIFAR10(
        data_path,
        False,
        download=True,
    )

    train_data_loader = data.DataLoader(
        train_dataset,
        batch_size=train_batch_size,
        sampler=data.RandomSampler(train_dataset),
        num_workers=8,
        collate_fn=collate_fn,
        pin_memory=True,
    )
    eval_data_loader = data.DataLoader(
        eval_dataset,
        batch_size=eval_batch_size,
        sampler=data.SequentialSampler(eval_dataset),
        num_workers=8,
        collate_fn=collate_fn,
        pin_memory=True,
    )

    return train_data_loader, eval_data_loader
# Make the necessary modifications to the floating point model.
class QATReadyMobileNetV2(MobileNetV2):
    def __init__(
        self,
        num_classes: int = 10,
        width_mult: float = 1.0,
        inverted_residual_setting: Optional[List[List[int]]] = None,
        round_nearest: int = 8,
    ):
        super().__init__(
            num_classes, width_mult, inverted_residual_setting, round_nearest
        )
        self.quant = QuantStub()
        self.dequant = DeQuantStub()

    def forward(self, x: Tensor) -> Tensor:
        x = self.quant(x)
        x = super().forward(x)
        x = self.dequant(x)
        return x
if not os.path.exists(model_path):
    os.makedirs(model_path, exist_ok=True)

# Floating-point model initialization.
float_model = QATReadyMobileNetV2()

# Prepare the dataset.
train_data_loader, eval_data_loader = prepare_data_loaders(
    data_path, train_batch_size, eval_batch_size
)

# Since the last layer of the model differs from the pre-trained model, a floating-point finetune is required.
optimizer = torch.optim.Adam(
    float_model.parameters(), lr=0.001, weight_decay=1e-3
)

best_acc = 0

for nepoch in range(epoch_num):
    float_model.train()
    train_one_epoch(
        float_model,
        nn.CrossEntropyLoss(),
        optimizer,
        None,
        train_data_loader,
        device,
    )

    # Floating-point accuracy test.
    float_model.eval()
    top1, top5 = evaluate(float_model, eval_data_loader, device)

    print(
        "Float Epoch {}: evaluation Acc@1 {:.3f} Acc@5 {:.3f}".format(
            nepoch, top1.avg, top5.avg
        )
    )

    if top1.avg > best_acc:
        best_acc = top1.avg
        # Save the best floating-point model parameters.
        torch.save(
            float_model.state_dict(),
            os.path.join(model_path, "float-checkpoint.ckpt"),
        )
Files already downloaded and verified
Files already downloaded and verified
....................................................................................................................................................................................................
Full cifar-10 train set: Loss 2.156 Acc@1 19.356 Acc@5 68.370
........................................
Float Epoch 0: evaluation Acc@1 30.970 Acc@5 84.260
...
....................................................................................................................................................................................................
Full cifar-10 train set: Loss 1.184 Acc@1 58.172 Acc@5 94.614
........................................
Float Epoch 9: evaluation Acc@1 63.040 Acc@5 95.940
Calibration
After the model has been modified and floating-point training is complete, calibration can be performed. This process inserts observers into the model and records the distribution of data at each location during forward passes, from which reasonable quantization parameters are computed:

- For some models, the required accuracy can be reached by calibration alone, without the more time-consuming quantization-aware training.
- Even if the model cannot meet the accuracy requirements after calibration, this process reduces the difficulty of subsequent quantization-aware training, shortens the training time, and improves the final accuracy.
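Conceptually, a calibration observer just tracks the range of values it sees and derives a quantization scale from that range. The sketch below is a simplified symmetric int8 min/max observer, not the plugin's actual observer implementation:

```python
class MinMaxObserver:
    """Track the observed value range and derive a symmetric int8 scale."""

    def __init__(self):
        self.min_val = float("inf")
        self.max_val = float("-inf")

    def observe(self, values):
        # Extend the running range with each forward pass.
        self.min_val = min(self.min_val, min(values))
        self.max_val = max(self.max_val, max(values))

    def scale(self):
        # Symmetric quantization: the scale covers the larger-magnitude side.
        bound = max(abs(self.min_val), abs(self.max_val))
        return bound / 127.0

obs = MinMaxObserver()
obs.observe([-0.5, 0.2, 0.9])   # first calibration batch
obs.observe([-1.27, 0.4])       # second batch extends the range
print(obs.scale())  # approximately 0.01 (= 1.27 / 127)
```

Feeding more calibration batches can only widen the observed range, which is why the statistics stabilize as more representative data is seen.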
######################################################################
# The user can modify the following parameters as required.
# 1. The batch_size used for calibration.
calib_batch_size = 256
# 2. The batch_size used for validation.
eval_batch_size = 256
# 3. The amount of data used for calibration; set to inf to use all data.
num_examples = float("inf")
# 4. Code name of the target hardware platform.
march = "nash-e"
# 5. Example input for model tracing and HBIR export.
example_input = torch.rand(1, 3, 32, 32, device=device)
######################################################################

# Before model transformation, the hardware platform on which the model will run must be set.
set_march(march)

# Transform the model into the calibration state to statistically characterize the numerical distribution of data at each location.
calib_model = prepare(
    float_model, example_input, default_calibration_qconfig_setter
)

# Prepare the dataset.
calib_data_loader, eval_data_loader = prepare_data_loaders(
    data_path, calib_batch_size, eval_batch_size
)

# Perform the calibration process (no backward pass required).
# Note the control of the model state here: the model must be in eval mode so that BN behaves as required.
calib_model.eval()
set_fake_quantize(calib_model, FakeQuantState.CALIBRATION)

with torch.no_grad():
    cnt = 0
    for image, target in calib_data_loader:
        image, target = image.to(device), target.to(device)
        calib_model(image)
        print(".", end="", flush=True)
        cnt += image.size(0)
        if cnt >= num_examples:
            break
    print()

# Test fake-quantization accuracy.
# Note the control of the model state here.
calib_model.eval()
set_fake_quantize(calib_model, FakeQuantState.VALIDATION)

top1, top5 = evaluate(
    calib_model,
    eval_data_loader,
    device,
)
print(
    "Calibration: evaluation Acc@1 {:.3f} Acc@5 {:.3f}".format(
        top1.avg, top5.avg
    )
)

# Save the calibration model parameters.
torch.save(
    calib_model.state_dict(),
    os.path.join(model_path, "calib-checkpoint.ckpt"),
)
INFO: The qconfig of classifier.1 will be set to default_qat_8bit_weight_32bit_out_fake_quant_qconfig
INFO: Template qconfig has been set!
INFO: Begin check qat model...
INFO: All fusable modules are fused in model!
INFO: All modules in the model run exactly once.
WARNING: Please check these modules qconfig if expected:
+---------------+---------------------------------------------------------+-----------------------------------------+
| module name | module type | msg |
|---------------+---------------------------------------------------------+-----------------------------------------|
| quant | <class 'horizon_plugin_pytorch.nn.qat.stubs.QuantStub'> | Fixed scale 0.0078125 |
| classifier.1 | <class 'horizon_plugin_pytorch.nn.qat.linear.Linear'> | activation is None. Maybe output layer? |
+---------------+---------------------------------------------------------+-----------------------------------------+
INFO: Check full result in ./model_check_result.txt
INFO: End check
Files already downloaded and verified
Files already downloaded and verified
....................................................................................................................................................................................................
........................................
Calibration: evaluation Acc@1 62.740 Acc@5 95.960
If the quantization accuracy of the model after calibration meets the requirements, you can proceed directly to the Model Deployment step; otherwise, perform quantization-aware training to further improve accuracy.
Quantization-Aware Training
Quantization-aware training makes the model aware of the impact of quantization during the training process by inserting fake-quantization nodes into the model, and fine-tunes the model parameters to improve post-quantization accuracy.
Note
Adaptation for multi-machine, multi-GPU training
Both the calibration and quantization-aware training processes support multi-machine, multi-GPU execution, and tasks are launched exactly as for floating-point models. In each process, simply call prepare first and then wrap the model with torch.nn.parallel.DistributedDataParallel.
######################################################################
# The user can modify the following parameters as required.
# 1. The batch_size used for training.
train_batch_size = 256
# 2. The batch_size used for validation.
eval_batch_size = 256
# 3. Number of training epochs.
epoch_num = 3
######################################################################

# Prepare the dataset.
train_data_loader, eval_data_loader = prepare_data_loaders(
    data_path, train_batch_size, eval_batch_size
)

# Convert the model to the QAT state.
qat_model = prepare(float_model, example_input, default_qat_qconfig_setter)

# Load the quantization parameters from the calibration model.
qat_model.load_state_dict(calib_model.state_dict())

# Conduct quantization-aware training.
# As a finetune process, quantization-aware training generally requires a small learning rate.
optimizer = torch.optim.Adam(
    qat_model.parameters(), lr=1e-3, weight_decay=1e-4
)

best_acc = 0

for nepoch in range(epoch_num):
    # Note how the training state of the QAT model is controlled here.
    qat_model.train()
    set_fake_quantize(qat_model, FakeQuantState.QAT)

    train_one_epoch(
        qat_model,
        nn.CrossEntropyLoss(),
        optimizer,
        None,
        train_data_loader,
        device,
    )

    # Note how the eval state of the QAT model is controlled here.
    qat_model.eval()
    set_fake_quantize(qat_model, FakeQuantState.VALIDATION)

    top1, top5 = evaluate(
        qat_model,
        eval_data_loader,
        device,
    )
    print(
        "QAT Epoch {}: evaluation Acc@1 {:.3f} Acc@5 {:.3f}".format(
            nepoch, top1.avg, top5.avg
        )
    )

    if top1.avg > best_acc:
        best_acc = top1.avg
        torch.save(
            qat_model.state_dict(),
            os.path.join(model_path, "qat-checkpoint.ckpt"),
        )
Files already downloaded and verified
Files already downloaded and verified
INFO: The qconfig of classifier.1 will be set to default_qat_8bit_weight_32bit_out_fake_quant_qconfig
INFO: Template qconfig has been set!
INFO: Begin check qat model...
INFO: All fusable modules are fused in model!
INFO: All modules in the model run exactly once.
WARNING: Please check these modules qconfig if expected:
+---------------+---------------------------------------------------------+-----------------------------------------+
| module name | module type | msg |
|---------------+---------------------------------------------------------+-----------------------------------------|
| quant | <class 'horizon_plugin_pytorch.nn.qat.stubs.QuantStub'> | Fixed scale 0.0078125 |
| classifier.1 | <class 'horizon_plugin_pytorch.nn.qat.linear.Linear'> | activation is None. Maybe output layer? |
+---------------+---------------------------------------------------------+-----------------------------------------+
INFO: Check full result in ./model_check_result.txt
INFO: End check
....................................................................................................................................................................................................
Full cifar-10 train set: Loss 1.267 Acc@1 55.638 Acc@5 93.620
........................................
QAT Epoch 0: evaluation Acc@1 63.500 Acc@5 96.540
...
....................................................................................................................................................................................................
Full cifar-10 train set: Loss 1.111 Acc@1 60.960 Acc@5 95.274
........................................
QAT Epoch 2: evaluation Acc@1 67.160 Acc@5 97.200
Model Deployment
After the fake-quantization accuracy meets the requirements through the processes described above, the model deployment steps can be executed, mainly exporting the HBIR model (export) and converting it to a fixed-point model (convert).
Export HBIR Model
Model deployment first requires exporting the fake-quantization model as an HBIR model.
Attention
- The batch_size of the example_input used during export determines the batch_size for model simulation and on-board execution. If different batch sizes are needed for simulation and on-board execution, export separate HBIR models with correspondingly sized example inputs.
- You can also skip the actual calibration and training in the Calibration and Quantization-Aware Training steps and go through the model deployment process first, to confirm that the model contains no operations that cannot be exported or compiled.
Note
Accuracy Description
The exported HBIR model is theoretically consistent with the torch model in computation logic, but the following differences lead to numerical inconsistency:

- Most nonlinear elementwise operators in the torch model are converted to a lookup-table implementation in HBIR.
- Operators with accumulation, such as reduce_sum and gemm, exhibit numerical fluctuation due to differing accumulation orders.
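The accumulation-order effect is a general property of floating-point arithmetic and can be demonstrated in plain Python:

```python
# Floating-point addition is not associative, so reductions such as
# reduce_sum or gemm can produce slightly different results depending
# on the order in which the hardware accumulates partial sums.
left_to_right = (0.1 + 0.2) + 0.3
right_to_left = 0.1 + (0.2 + 0.3)
print(left_to_right)                   # 0.6000000000000001
print(right_to_left)                   # 0.6
print(left_to_right == right_to_left)  # False
```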
######################################################################
# The user can modify the following parameters as required.
# 1. Which model to use as input for this process; either calib_model or qat_model.
base_model = qat_model
######################################################################
from horizon_plugin_pytorch.quantization.hbdk4 import export

base_model.eval()
set_fake_quantize(base_model, FakeQuantState.VALIDATION)

hbir_qat_model = export(base_model, (example_input,))
INFO: Model ret: Tensor(shape=(1, 10), dtype=torch.float32, device=cuda:0)
Convert to Fixed-point Model
After exporting to HBIR, the model can be converted to a fixed-point model. The results of the fixed-point model are generally considered identical to those of the compiled model.
Attention
- HBIR models support only a single Tensor or Tuple[Tensor] as input, and only Tuple[Tensor] as output.
- Complete numerical agreement between the fixed-point model and the fake-quantization model cannot be achieved, so take the accuracy of the fixed-point model as the standard. If the fixed-point accuracy does not meet requirements, continue quantization-aware training.
Note
Accuracy Description
After convert, the model switches from floating-point fake-quantization computation to integer computation, which leads to numerical fluctuation.
# Transform the model to the fixed-point state. Note that the march here must be chosen among nash-e/m/p.
hbir_quantized_model = hb4.convert(
    hbir_qat_model,
    "nash-e",
)

# Dataloader for testing the accuracy of the HBIR model. Note that the batch
# size must match that of the example input used when exporting HBIR.
_, eval_hbir_data_loader = prepare_data_loaders(
    data_path, train_batch_size, 1
)

def evaluate_hbir(
    model: hb4.Module, data_loader: data.DataLoader
) -> Tuple[AverageMeter, AverageMeter]:
    top1 = AverageMeter("Acc@1", ":6.2f")
    top5 = AverageMeter("Acc@5", ":6.2f")

    for image, target in data_loader:
        image, target = image.cpu(), target.cpu()
        # Default input/output names are _input_{n} and _output_{n}; users can
        # change them via parameters when exporting HBIR.
        output = model["forward"].feed({"_input_0": image})["_output_0"]
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        top1.update(acc1, image.size(0))
        top5.update(acc5, image.size(0))

    return top1, top5

# Test the accuracy of the fixed-point model.
top1, top5 = evaluate_hbir(
    hbir_quantized_model,
    eval_hbir_data_loader,
)
print(
    "Quantized model: evaluation Acc@1 {:.3f} Acc@5 {:.3f}".format(
        top1.avg, top5.avg
    )
)
Files already downloaded and verified
Files already downloaded and verified
Quantized model: evaluation Acc@1 65.620 Acc@5 93.000
(Optional) Model Modification
Before model compilation, the board-side deployable model can also be modified. The common operations and the corresponding API interfaces are as follows:

- After model export and before convert:
  - Split batched inputs by calling the insert_split() interface.
  - Insert image preprocessing nodes into the model:
    a. Adjust the layout to NHWC for subsequent operations by calling the insert_transpose() interface.
    b. Perform image normalization by calling the insert_image_preprocess() interface.
    c. Perform color conversion (typically nv12 input for board-side deployment) by calling the insert_image_convert() interface.
    d. Configure an input as a resizer input to support ROI-based cropping and scaling by calling the insert_roi_resize() interface.
  - Adjust the input and output data layout by calling the insert_transpose() interface.
- After model convert and before compile, remove I/O operators (Quantize/Dequantize/Cast, etc.) by calling the remove_io_op() interface.

For details on the above APIs, refer to the HBDK Tool API Reference section.
Model Compilation
After the accuracy of the fixed-point model has been tested and confirmed to meet requirements, the model can be compiled, performance-tested, and visualized.
Attention
The model used for perf should have gone through at least one calibration pass (with no limit on the number of steps) so that the statistics stored in the model match the actual situation; otherwise the perf results may be inaccurate.
######################################################################
# The user can modify the following parameters as required.
# 1. The optimization level enabled at compile time; the higher the level, the faster the compiled model runs on board, but the slower the compilation.
compile_opt = 1
######################################################################

# Model compile.
hb4.compile(
    hbir_quantized_model,
    os.path.join(model_path, "model.hbm"),
    "nash-e",
    opt=compile_opt,
)

# Model perf.
hb4.hbm_perf(
    os.path.join(model_path, "model.hbm"),
    output_dir=model_path,
)
[10h:58m:15s:654039us INFO hbrt4_loader::parsing] pid:212735 tid:212735 hbrt4_loader/src/parsing.rs:42: Load hbm header from file; filename="model/mobilenetv2/model.hbm"
[10h:58m:15s:655241us INFO hbrt4_log::logger] pid:212735 tid:212735 hbrt4_log/src/logger.rs:388: Logger of HBRT4 initialized, version = 4.1.2
[10h:58m:15s:655253us INFO hbrt4_loader::parsing] pid:212735 tid:212735 hbrt4_loader/src/parsing.rs:73: Load hbm from file; filename="model/mobilenetv2/model.hbm"
FPS=11518.08, latency = 86.799999999999997 us, DDR = 2597376 bytes (see model/mobilenetv2/forward.html)
HBDK hbm perf SUCCESS
# Model Visualization.
hb4.visualize(hbir_quantized_model, "mobilenetv2_cifar10.onnx")
Temporary onnx file saved to mobilenetv2_cifar10.onnx