import copy
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torch.quantization import DeQuantStub, QuantStub
from horizon_plugin_pytorch import March, set_march
from horizon_plugin_pytorch.fx.jit_scheme import Tracer
from horizon_plugin_pytorch.quantization import (
FakeQuantState,
get_qconfig,
PrepareMethod,
prepare,
set_fake_quantize,
)
class Net(torch.nn.Module):
    """Toy network used to demonstrate QAT preparation with JIT tracing.

    The forward pass deliberately contains patterns that exercise the
    tracer: un-quantized pre-processing, two quantized inputs, dynamic
    Python loops, an annotated fusible Conv+BN+Add block, and a
    training/eval control-flow branch.
    """

    def __init__(self, input_size: int, class_num: int) -> None:
        super().__init__()
        # One QuantStub per network input so each gets its own scale.
        self.quant0 = QuantStub()
        self.quant1 = QuantStub()
        self.dequant = DeQuantStub()
        self.conv = nn.Conv2d(3, 3, 1)
        self.bn = nn.BatchNorm2d(3)
        # Kernel size == input_size collapses spatial dims to 1x1,
        # so squeeze() in forward yields (batch, class_num) logits.
        self.classifier = nn.Conv2d(3, class_num, input_size)
        # NOTE(review): self.loss is never used in forward (forward calls
        # F.cross_entropy directly) — confirm whether it can be removed.
        self.loss = nn.CrossEntropyLoss()

    def forward(self, input, other, target=None):
        # Pre-processing that does not need quantization; with the
        # JIT_STRIP prepare method these ops are stripped from the
        # captured graph.
        input = (input - 128) / 128.0
        x = self.quant0(input)
        y = self.quant1(other)
        n = np.random.randint(1, 5)
        m = np.random.randint(1, 5)
        # Because the Python code is not regenerated, these dynamic
        # loops are preserved as-is in the QAT model.
        for _ in range(n):
            for _ in range(m):
                # Code inside a dynamic loop that involves operator
                # replacement or fusion must be annotated. The
                # annotation marks the fusible logic itself, not the
                # for-loop.
                with Tracer.dynamic_block(self, "ConvBnAdd"):
                    x = self.conv(x)
                    x = self.bn(x)
                    x = x + y
        x = self.classifier(x).squeeze()
        # Because the Python code is not regenerated, this dynamic
        # control flow is preserved in the QAT model.
        if self.training:
            assert target is not None
            x = self.dequant(x)
            # NOTE(review): softmax followed by cross_entropy applies
            # softmax twice (cross_entropy expects raw logits) — looks
            # intentional for this demo, but confirm. `target` here is a
            # per-class probability tensor (see caller), which
            # F.cross_entropy supports on recent torch versions.
            return F.cross_entropy(torch.softmax(x, dim=1), target)
        else:
            return torch.argmax(x, dim=1)
model = Net(6, 2)
# Training inputs: raw image-like data in [0, 256), a second quantized
# input, and one-hot class-probability targets for cross_entropy.
train_example_input = (
    torch.rand(2, 3, 6, 6) * 256,
    torch.rand(2, 3, 6, 6),
    torch.tensor([[0.0, 1.0], [1.0, 0.0]]),
)
# Eval takes only the first two inputs; `target` defaults to None.
eval_example_input = train_example_input[:2]
# Trace/prepare in eval mode so the eval branch of forward is captured
# and Conv+BN fusion can be applied.
model.eval()
set_march("nash-e")
model.qconfig = get_qconfig()
# deepcopy so the example inputs used for tracing are not mutated —
# presumably prepare may consume/modify them; TODO confirm.
qat_model = prepare(
    model,
    example_inputs=copy.deepcopy(eval_example_input),
    method=PrepareMethod.JIT_STRIP,
)
qat_model.graph.print_tabular()
# Expected output:
# opcode         name              target                                                     args                             kwargs
# -------------  ----------------  ---------------------------------------------------------  -------------------------------  ----------
# placeholder    input_0           input_0                                                    ()                               {}
# call_module    quant0            quant0                                                     (input_0,)                       {}
# placeholder    input_1           input_1                                                    ()                               {}
# call_module    quant1            quant1                                                     (input_1,)                       {}
# call_module    conv              conv                                                       (quant0,)                        {}
# call_module    bn                bn                                                         (conv,)                          {}
# get_attr       _generated_add_0  _generated_add_0                                           ()                               {}
# call_method    add_2             add                                                        (_generated_add_0, bn, quant1)   {}
# scope_end nodes are inserted automatically during tracing to mark the
# boundary of a sub-module or dynamic block; they perform no computation.
# call_function  scope_end         <function Tracer.scope_end at 0x7f65d90e5e50>              ('_dynamic_block_ConvBnAdd',)    {}
# call_module    conv_1            conv                                                       (add_2,)                         {}
# call_module    bn_1              bn                                                         (conv_1,)                        {}
# get_attr       _generated_add_1  _generated_add_0                                           ()                               {}
# call_method    add_3             add                                                        (_generated_add_1, bn_1, quant1) {}
# call_function  scope_end_1       <function Tracer.scope_end at 0x7f65d90e5e50>              ('_dynamic_block_ConvBnAdd',)    {}
# call_module    classifier        classifier                                                 (add_3,)                         {}
# call_function  squeeze           <method 'squeeze' of 'torch._C._TensorBase' objects>       (classifier,)                    {}
# call_function  argmax            <built-in method argmax of type object at 0x7f66f04cf820>  (squeeze,)                       {'dim': 1}
# call_function  scope_end_2       <function Tracer.scope_end at 0x7f65d90e5e50>              ('',)                            {}
# output         output            output                                                     ((argmax,),)                     {}
print(qat_model)
# Expected output:
# GraphModuleImpl(
#   (quant0): QuantStub(
#     (activation_post_process): FakeQuantize(
#       fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8), quant_min=-128, quant_max=127, dtype=qint8, qscheme=torch.per_tensor_symmetric, ch_axis=-1, scale=tensor([1.]), zero_point=tensor([0])
#       (activation_post_process): MinMaxObserver(min_val=tensor([]), max_val=tensor([]))
#     )
#   )
#   (quant1): QuantStub(
#     (activation_post_process): FakeQuantize(
#       fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8), quant_min=-128, quant_max=127, dtype=qint8, qscheme=torch.per_tensor_symmetric, ch_axis=-1, scale=tensor([1.]), zero_point=tensor([0])
#       (activation_post_process): MinMaxObserver(min_val=tensor([]), max_val=tensor([]))
#     )
#   )
#   (dequant): DeQuantStub()
#   (conv): Identity()  # Because the forward code is unchanged, conv and bn are still executed,
#   (bn): Identity()    # so after fusion Conv and Bn must be replaced with Identity.
#   (classifier): Conv2d(
#     3, 2, kernel_size=(6, 6), stride=(1, 1)
#     (activation_post_process): FakeQuantize(
#       fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8), quant_min=-128, quant_max=127, dtype=qint8, qscheme=torch.per_tensor_symmetric, ch_axis=-1, scale=tensor([1.]), zero_point=tensor([0])
#       (activation_post_process): MinMaxObserver(min_val=tensor([]), max_val=tensor([]))
#     )
#     (weight_fake_quant): FakeQuantize(
#       fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8), quant_min=-128, quant_max=127, dtype=qint8, qscheme=torch.per_channel_symmetric, ch_axis=0, scale=tensor([1., 1.]), zero_point=tensor([0, 0])
#       (activation_post_process): MinMaxObserver(min_val=tensor([]), max_val=tensor([]))
#     )
#   )
#   (loss): CrossEntropyLoss()
#   (_generated_add_0): ConvAdd2d(  # '+' is automatically replaced by its Module form, with Conv and Bn fused in.
#     3, 3, kernel_size=(1, 1), stride=(1, 1)
#     (activation_post_process): FakeQuantize(
#       fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8), quant_min=-128, quant_max=127, dtype=qint8, qscheme=torch.per_tensor_symmetric, ch_axis=-1, scale=tensor([1.]), zero_point=tensor([0])
#       (activation_post_process): MinMaxObserver(min_val=tensor([]), max_val=tensor([]))
#     )
#     (weight_fake_quant): FakeQuantize(
#       fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8), quant_min=-128, quant_max=127, dtype=qint8, qscheme=torch.per_channel_symmetric, ch_axis=0, scale=tensor([1., 1., 1.]), zero_point=tensor([0, 0, 0])
#       (activation_post_process): MinMaxObserver(min_val=tensor([]), max_val=tensor([]))
#     )
#   )
# )
# Switch to training mode and enable QAT fake-quantize, then run a few
# forward/backward iterations; in train mode forward returns the loss.
qat_model.train()
set_fake_quantize(qat_model, FakeQuantState.QAT)
for _ in range(3):
    ret = qat_model(*train_example_input)
    ret.backward()