6.4.3. 开发指南

6.4.3.1. 浮点模型的要求

symbolic_trace

和 PyTorch 的量化训练类似,horizon_plugin_pytorch 基于 fx 设计和开发,因此,要求浮点模型必须是可以正确的完成 symbolic_trace 的

仅支持部分算子

由于 BPU 只支持数量有限的算子,因此,horizon_plugin_pytorch 只支持算子列表中的算子和基于 BPU 限制而内部特殊定义的特殊算子。

构建量化友好模型

浮点模型变为定点模型的过程存在一定的精度误差,越是量化友好的浮点模型, qat 精度提升越容易,量化后的精度也越高。一般而言,有以下几种情况会导致模型变得量化不友好:

  1. 使用有精度风险的算子。例如: softmax , layernorm 等(详见 op 文档),这类算子一般底层由查表或多个 op 拼接实现,容易发生掉点问题。

  2. 一次 forward 中多次调用同一算子。同一算子多次调用,对应的输出分布存在差异,但只会统计一组量化参数,当多次调用的输出分布差异过大时,量化误差会变大。

  3. add , cat 等多输入算子的不同输入差异过大,可能造成较大误差。

  4. 数据分布不合理。plugin 采用的是均匀对称量化,所以 0 均值的均匀分布最好,应尽量避免长尾和离群点。同时,数值范围需要与量化 bit 相匹配,如果使用int8量化分布为 [-1000, 1000] 均匀分布的数据,那么精度显然也是不够的。例如,下面三个分布图,从左到右对量化的友好性依次递减,模型中大部分数值的分布应当为中间这种分布。在实际使用中,可以用 debug 工具查看模型 weight 和 feature map 的分布是否量化友好。因为模型冗余性的存在,有些看起来分布非常量化不友好的 op 并不会显著降低模型的最终精度,需要结合实际的 qat 训练难度和最后达到的量化精度综合考虑。

data_distribution

那么如何使得模型更加量化友好呢?具体来说:

  1. 尽量少使用精度风险过大的算子,详见 op 文档。

  2. 保证多次调用的共享算子每次调用的输出分布差异不要太大,或者将共享算子拆开分别单独使用。

  3. 避免多输入算子不同输入的数值范围差异过大。

  4. 使用 int16 量化数值范围和误差都非常大的 op 。可通过 debug 工具找到这类 op 。

  5. 通过调大 weight decay ,增加数据增强等方式防止模型过拟合。过拟合模型容易出现较大数值,且对输入非常敏感,轻微的误差可能导致输出完全错误。

  6. 使用 BN 。

  7. 对模型输入做关于0对称的归一化。

需要注意的是, qat 自身具有一定的调整能力,量化不友好并不代表不能量化,很多情况下,即使出现上面的不适合量化的现象,仍然可以量化得很好。因为上述建议也可能会导致浮点模型精度下降,所以应当在 qat 精度无法达标时再尝试上述建议,尤其是 1 - 5 条建议,最后应当是在浮点模型精度和量化模型精度中找一个平衡点。

6.4.3.2. qconfig 详解

什么是 qconfig

模型的量化方式由 qconfig 决定,在准备 qat / calibration 模型之前,需要先给模型设置 qconfig。我们不推荐您自定义 qconfig,尽量只使用预定义好的qconfig变量,因为自定义 qconfig 需要对具体的处理器限制认知清晰,详细了解训练工具的工作原理,定义出错可能导致模型无法正常收敛、模型无法编译等问题,浪费大量时间和人力。

目前,Plugin 中维护了两个版本的qconfig,早期版本的 qconfig 将在不久的将来被废弃,我们只推荐您使用此文档中介绍的 qconfig 用法。

如何获取 qconfig

  1. 使用封装好的 qconfig 变量。这些 qconfig 存放在 horizon_plugin_pytorch/quantization/qconfig.py 中,可以适用于绝大多数情况。包括:

from horizon_plugin_pytorch.quantization.qconfig import (
    default_calib_8bit_fake_quant_qconfig,
    default_qat_8bit_fake_quant_qconfig,
    default_qat_8bit_fixed_act_fake_quant_qconfig,
    default_calib_8bit_weight_16bit_act_fake_quant_qconfig,
    default_qat_8bit_weight_16bit_act_fake_quant_qconfig,
    default_qat_8bit_weight_16bit_fixed_act_fake_quant_qconfig,
    default_qat_8bit_weight_32bit_out_fake_quant_qconfig, # 参考算子列表,支持高精度输出的算子可以设置此 qconfig 获得更高的精度
    default_calib_8bit_weight_32bit_out_fake_quant_qconfig, # 参考算子列表,支持高精度输出的算子可以设置此 qconfig 获得更高的精度
)
  1. 使用 get_default_qconfig 接口。此接口较固定 qconfig 变量更灵活,我们推荐您对量化和硬件限制有清晰认知之后再使用。常用参数和解释如下:

from horizon_plugin_pytorch.quantization.qconfig import get_default_qconfig

qconfig = get_default_qconfig(
    activation_fake_quant="fake_quant",  # 支持 fake_quant, lsq, pact,常用 fake quant
    weight_fake_quant="fake_quant", # 支持 fake_quant, lsq, pact,常用 fake quant
    activation_observer="min_max", # 支持 min_max, fixed_scale, clip, percentile, clip_std, mse, kl
    weight_observer="min_max", # 支持 min_max, fixed_scale, clip, percentile, clip_std, mse, kl
    activation_qkwargs={
        "dtype": qint16, # 由具体算子决定是否支持 int16
        "is_sync_quantize": False, # 是否同步统计数据,默认关闭提升forward速度
        "averaging_constant": 0.01 # 滑动平均系数,设置为0时,scale不更新
    },
    weight_qkwargs={ # 只支持 dtype = qint8, qscheme = torch.per_channel_symmetric, ch_axis = 0, 不建议做额外配置
        "dtype": qint8,
        "qscheme": torch.per_channel_symmetric,
        "ch_axis": 0,
    },
)

如何设置 qconfig

共有三种设置方法,我们推荐您使用前两种,最后一种设置方式将废弃。

  1. 直接设置 qconfig 属性。此方法优先级最高,其余方法不会覆盖直接设置的 qconfig。

model.qconfig = default_qat_8bit_fake_quant_qconfig
  1. qconfig 模板。在 prepare 接口上指定 qconfig setter 和 example_inputs,自动为模型设置 qconfig。

model = prepare_qat_fx(
    model,
    example_inputs=data,
    qconfig_setter=default_qat_qconfig_setter,
)
  1. qconfig_dict。在 prepare_qat_fx 接口上指定 qconfig_dict。此用法将逐步废弃,如无兼容性需求,不推荐再使用,这里不展开介绍。

model = prepare_qat_fx(
    model,
    qconfig_dict={"": default_qat_qconfig_setter},
)

qconfig 模板

长期以来,配置 qconfig 出错的问题经常发生,因此我们开发了 qconfig 模板。qconfig 模板基于 subclass trace 方案感知模型的图结构,并按设定的规则自动设置 qconfig,是我们最推荐的设置 qconfig 方法。用法如下:

qat_model = prepare_qat_fx(
    model,
    example_inputs=example_input,  # 用来感知图结构
    qconfig_setter=( # qconfig 模板,支持传入多个模板,优先级从高到低。
        sensitive_op_qat_8bit_weight_16bit_act_qconfig_setter(table, ratio=0.2),
        default_calibration_qconfig_setter,
    )
)
模板的优先级低于直接给模型设置 qconfig 属性,如果模型在 prepare 之前已经使用 model.qconfig = xxx 进行了配置,那么模板将不会生效。如果没有特殊需求,我们不推荐将两者混合使用,这很容易引发低级错误。绝大多数情况下,我们推荐您使用模板和 model.qconfig = xxx 两种设置方式中的一种即可满足需求。

模板可分为三类:

  1. 固定模板。固定模板中 calibration / qat / qat_fixed_act_scale 区别在于使用的 observer 类型和 scale 更新逻辑,分别用于校准,qat 训练,固定 activation scale qat 训练。default 模板( default_calibration_qconfig_setter / default_qat_qconfig_setter / default_qat_fixed_act_qconfig_setter )会做三件事:首先,将可以设置的高精度输出都设置上,对于不支持高精度的输出将给出提示;然后,从 grid sample 算子的 grid 输入向前搜索,直到出现第一个 gemm 类算子或者QuantStub,将中间的所有算子都设置为 int16。根据经验这里的 grid 一般表达范围较宽,int8 有较大可能不满足精度需求;最后,将其余算子设置为 int8。int16 模板( qat_8bit_weight_16bit_act_qconfig_setter / qat_8bit_weight_16bit_fixed_act_qconfig_setter / calibration_8bit_weight_16bit_act_qconfig_setter )会做两件事:首先,将可以设置的高精度输出都设置上,对于不支持高精度的输出将给出提示;其次,将其余算子设置为 int16。

from horizon_plugin_pytorch.quantization.qconfig_template import (
    default_calibration_qconfig_setter,
    default_qat_qconfig_setter,
    default_qat_fixed_act_qconfig_setter,
    qat_8bit_weight_16bit_act_qconfig_setter,
    qat_8bit_weight_16bit_fixed_act_qconfig_setter,
    calibration_8bit_weight_16bit_act_qconfig_setter,
)
  1. 敏感度模板。敏感度模板有 sensitive_op_calibration_8bit_weight_16bit_act_qconfig_setter, sensitive_op_qat_8bit_weight_16bit_act_qconfig_setter, sensitive_op_qat_8bit_weight_16bit_fixed_act_qconfig_setter,三者的区别和固定模板中三者的区别一致,也是分别用于校准,qat 训练,固定 activation scale qat 训练。 敏感度模板的第一个输入是精度 debug 工具产生的敏感度结果,第二个参数可以指定 ratio 或 topk ,敏感度模板会将量化敏感度最高的 topk 个算子设置为 int16。搭配固定模板,可以轻松实现混合精度调优。

from horizon_plugin_pytorch.quantization.qconfig_template import (
    default_calibration_qconfig_setter,
    default_qat_qconfig_setter,
    default_qat_fixed_act_qconfig_setter,
    qat_8bit_weight_16bit_act_qconfig_setter,
    qat_8bit_weight_16bit_fixed_act_qconfig_setter,
    calibration_8bit_weight_16bit_act_qconfig_setter,
    sensitive_op_qat_8bit_weight_16bit_act_qconfig_setter,
    sensitive_op_qat_8bit_weight_16bit_fixed_act_qconfig_setter,
    sensitive_op_calibration_8bit_weight_16bit_act_qconfig_setter,
)

table = torch.load("output_0-0_dataindex_1_sensitive_ops.pt")

qat_model = prepare_qat_fx(
    model,
    example_inputs=example_input,
    qconfig_setter=( 
        sensitive_op_qat_8bit_weight_16bit_fixed_act_qconfig_setter(table, ratio=0.2),
        default_calibration_qconfig_setter,
    )
)
  1. 自定义模板。自定义模板只有 ModuleNameQconfigSetter,需要传入模块名和对应 qconfig 的字典,一般用于设置 fixed scale 等特殊需求,可以和固定模板,敏感度模板搭配使用。

from horizon_plugin_pytorch.quantization.qconfig_template import (
    default_calibration_qconfig_setter,
    default_qat_qconfig_setter,
    default_qat_fixed_act_qconfig_setter,
    qat_8bit_weight_16bit_act_qconfig_setter,
    qat_8bit_weight_16bit_fixed_act_qconfig_setter,
    calibration_8bit_weight_16bit_act_qconfig_setter,
    sensitive_op_qat_8bit_weight_16bit_act_qconfig_setter,
    sensitive_op_qat_8bit_weight_16bit_fixed_act_qconfig_setter,
    sensitive_op_calibration_8bit_weight_16bit_act_qconfig_setter,
    ModuleNameQconfigSetter,
)

table = torch.load("output_0-0_dataindex_1_sensitive_ops.pt")

module_name_to_qconfig = {
    "op_1": default_qat_8bit_fake_quant_qconfig,
    "op_2": get_default_qconfig(
        activation_observer="fixed_scale",
        activation_qkwargs={
            "dtype": qint16,
            "scale": OP2_MAX / QINT16_MAX,
        },
    )
}

qat_model = prepare_qat_fx(
    model,
    example_inputs=example_input,
    qconfig_setter=(
        ModuleNameQconfigSetter(module_name_to_qconfig),
        sensitive_op_qat_8bit_weight_16bit_fixed_act_qconfig_setter(table, ratio=0.2),
        default_calibration_qconfig_setter,
    )
)

6.4.3.3. Calibration 指南

在量化中,一个重要的步骤是确定量化参数,合理的初始量化参数能够显著提升模型精度并加快模型的收敛速度。Calibration 就是在浮点模型中插入 Observer,使用少量训练数据,在模型 forward 过程中统计各处的数据分布,以确定合理的量化参数的过程。虽然不做 Calibration 也可以进行量化训练,但一般来说,它对量化训练有益无害,所以推荐用户将此步骤作为必选项。

流程和示例

Calibration 与 QAT 的整体流程如下图所示:

quick_start

下面分别介绍各个步骤:

  1. 构建并训练浮点模型。参考 horizon_plugin_pytorch 快速入门章节中的 获取浮点模型 小节内容。

  2. 在浮点模型上插入 Observer 节点。参考 horizon_plugin_pytorch 快速入门章节中的 Calibration 小节内容。使用 prepare_qat_fx 方法转化浮点模型前,需要为模型设置 qconfig

    model.qconfig = horizon.quantization.get_default_qconfig()
    

    get_default_qconfig 可以为 weightactivation 设置不同的 observer 。目前,calibration 可选 observer 有 “min_max”、 “percentile”、 “mse”、 “kl” 和 “mix”。如无特殊需求,weight_observer 推荐使用默认的 “min_max”,activation_observer 推荐使用 “mse”。特殊用法和调试技巧见下面的常见算法介绍。

    fake_quant 参数对 Calibration 结果无影响,保留默认状态即可。

    def get_default_qconfig(
        activation_fake_quant: Optional[str] = "fake_quant",
        weight_fake_quant: Optional[str] = "fake_quant",
        activation_observer: Optional[str] = "min_max",
        weight_observer: Optional[str] = "min_max",
        activation_qkwargs: Optional[Dict] = None,
        weight_qkwargs: Optional[Dict] = None,
    ):
    
  3. 设置 fake quantize 状态为 CALIBRATION

    horizon.quantization.set_fake_quantize(model, horizon.quantization.FakeQuantState.CALIBRATION)
    

    fake quantize 一共有三种状态,分别需要在 QATcalibrationvalidation 前将模型的 fake quantize 设置为对应的状态。在 calibration 状态下,仅观测各算子输入输出的统计量。在 QAT 状态下,除观测统计量外还会进行伪量化操作。而在 validation 状态下,不会观测统计量,仅进行伪量化操作。

    class FakeQuantState(Enum):
        QAT = "qat"
        CALIBRATION = "calibration"
        VALIDATION = "validation"
    
  4. calibration。把准备好的校准数据喂给模型,模型在 forward 过程中由 observer 观测相关统计量。

  5. 设置模型状态为 eval 并设置 fake quantize 状态为 VALIDATION

    model.eval()
    horizon.quantization.set_fake_quantize(model, horizon.quantization.FakeQuantState.VALIDATION)
    
  6. 验证 calibration 效果。如果效果满意,则可以直接将模型转为定点或在此基础上进行量化训练,不满意则调整 calibration qconfig 中的参数继续 calibration。

常用算法介绍

备注:

有关每个算子的参数说明,请参考文末 API 文档。
算法 速度排名 精度排名 易用性排名
min_max 1 5 1
percentile 2 4 4
mse 4 1 2
kl 5 2 3
mix 3 2 1

常用的几种校准方法性能如上表所示,数字越小越好,速度表示相同数据校准耗时,精度表示该方法在大多数模型上的校准效果,易用性表示该方法的调参复杂度。

对于同一模型而言,不同方法不同参数的精度/速度会存在较大差别,最新的一些研究工作也表明,没有一种方法可以在所有模型上都取得最好的精度,需要针对地调整其参数。所以推荐用户对这几种校准方法都进行尝试。

  1. min_max。此方法仅统计最大值最小值的滑动平均,用于快速确定 Batch size、average_constant 等通用参数,没有太多技巧。

  2. percentile。此方法是所有方法中精度上限最高的,但也是调整起来最麻烦的,如果通过其他方法或本方法的默认参数就可以满足精度要求,那么不建议在调参上花太多时间。percentile 可调的参数一共有两个 bins、percentile。bins 越多,max 的候选项间隔越小,可供调整的粒度越细,但也意味着更高的计算耗时。建议先确定 percentile 再调整 bins,两者交替迭代缩小调参范围直至达到满意的效果。绝大部分情况下 bins 取 2048 提供的调整粒度完全足够,不需要单独调整这个参数。以下是一个模型的调参路径:

顺序 percentile bins 精度
1 99.99 2048 53.75
2 99.99 4096 54.38
3 99.995 4096 16.25
4 99.985 4096 32.67
5 99.9875 4096 57.06
6 99.9875 8192 62.84
7 99.98875 8192 57.62
8 99.988125 8192 63.15

在这个例子中,可以看到仔细调整后,精度提升了大约 10%。 模型中不同 op 的输入输出之间存在很大差异,一组全局的 percentile 参数可能很难满足所有 op 的需求,对精度要求较高时,可以先通过上面的方法找到较好的全局参数,再通过 debug 工具找到误差较大的几个 op,单独为这几个 op 设置 percentile 参数,设置方式参照 qconfig 设置。下面列举几种常见的容易导致误差较大的数据分布:

calibration_percentile_longtail

超长尾分布,percentile 的取值应当小一些,图中 99.9 是较好的取值。

calibration_percentile_bimodal

值域过大,且分布并不集中在一处,这种情况无论是保留尾部还是忽略尾部都会带来较大的精度损失,应该在训练浮点模型时通过调整 weight decay 等参数避免这种情况的出现。

calibration_percentile_ln

layernorm 的输出分布会呈现出若干集中度非常高的区域,此时 percentile 按照正常方法调整对于量化结果不会有任何影响,需要将 percentile 调整幅度增加。

  1. mse。可调整的参数只有 stride,默认 stride 为 1,会逐步尝试最大值的 100 分位并选出量化反量化前后误差最小(L2 距离)的分位对应的值。此方法对大模型耗时较高,在合理范围内调大 stride 可以在保证精度的前提下减少耗时,stride 调整过大会影响精度。注意,调整此方法的参数只能优化耗时,并不能显著提升精度。

  2. kl。可调的参数一共有两个 bin 和 update_interval。由于此方法耗时过长,不建议调整默认 bin。update_interval 默认为 1,表示间隔多少个 forward step 计算一次 KL,调大可以减少耗时(不影响精度),但需要保证 update_interval 不超过总的 calibration step,否则无法得到正常的量化参数。一般推荐直接将 update_interval 设为 calibration step,这样前面的 forward step 只采集数据更新直方图,只有最后一个 step 才会计算 KL 和 scale,可以最大程度减少 KL 的耗时,同时由于最终的直方图包含所有输入数据的统计信息,因此不会对精度造成影响。

  3. mix。此方法为混合校准,对于每一个需要统计的地方,都会尝试 percentile 方法的不同参数,选出量化反量化前后误差最小(L2 距离)的方法。自动化程度较高,没有需要调整的参数。

调参技巧

  1. calibration 数据越多越好,但因为边际效应的存在,当数据量大到一定程度后,对精度的提升将非常有限。如果训练集较小,可以全部用来 calibration,如果训练集较大,可以结合 calibration 耗时挑选大小合适的子集,建议至少进行 10 - 100 个 step 的校准。

  2. 数据可以做水平翻转这类 augmentation,不要做马赛克这种 augmentation。尽量使用 infer 阶段的前处理 + 训练数据进行校准。

  3. Batch size 尽可能大,如果数据噪声较大或模型离群点较多,可以适当减小。此参数应当在尝试 min max 方法时确定。

  4. average_constant 表示每个 step 对最大值最小值的影响,average_constant 越小,当前 step 的影响越小,历史滑动均值的影响越大。该参数需要结合数据量在 0.01 ~ 0.5 之间调整。当数据量充足时(step > 100),average_constant 取 0.01,数据量不足时,average_constant 酌情增加,极端情况下,只有 2 个 step 的数据,average_constant 取 0.5。此参数应当在尝试 min max 方法时确定,之后其他方法都沿用此参数。

  5. calibration 模型精度较好时,固定 feature map 的量化参数进行 QAT 训练可以取得更好的效果,精度较差时,则不能固定 calibration 得到的量化参数。关于精度是好还是坏,没有明确的标准,需要去尝试。比如:某模型精度为 100,如果 calibration 精度为 50,那么精度肯定称不上好,但如果 calibration 精度为 95,那么这个精度是否可以达到固定 feature map 量化参数的程度就需要尝试了,通常做法是固定与不固定都做实验进行对比。

  6. 优先尝试 min max 方法,该方法是速度最快的,用来跑通 calibration 流程,调整并确定 batch size 和 average_constant 两个参数,接着分别尝试 percentile、kl、mse 和 mix 四种方法并选取效果最好的方法。

Observer 参数文档

class horizon_plugin_pytorch.quantization.observer_v2.KLObserver(
    bins: int = 512,
    update_interval: int = 1,
    averaging_constant: float = 0.01,
    ch_axis: int = -1,
    dtype: Union[torch.dtype, horizon_plugin_pytorch.dtype.QuantDType] = 'qint8',
    qscheme: torch.qscheme = torch.per_tensor_symmetric,
    quant_min: int = None,
    quant_max: int = None,
    is_sync_quantize: bool = False,
    factory_kwargs: Dict = None
)

KL 观察器(KLObserver)
基于直方图的 KL 散度观察器。直方图在线计算且不会保存。

参数:

  • bins – Number of histograms bins.

  • update_interval – Interval of computing KL entropy and update min/max. KLObserver will constantly collect histograms of activations, but only perform KL calculation when update_interval is satisfied. if it is set to 1, KL entropy will be computed every forward step. Larger interval guarantees less time and does no harm to calibration accuracy. Set it to the total calibration steps can achieve best performance. update_interval must be no greater than total calibration steps, otherwise no min/max will be computed.

  • averaging_constant – Averaging constant for min/max.

  • ch_axis – Channel axis.

  • dtype – Quantized data type.

  • qscheme – Quantization scheme to be used.

  • quant_min – Min quantization value. Will follow dtype if unspecified.

  • quant_max – Max quantization value. Will follow dtype if unspecified.

  • is_sync_quantize – If sync statistics when training with multiple devices.

  • factory_kwargs – kwargs which are passed to factory functions for min_val and max_val.

forward(x_orig)

定义每次调用时执行的计算。

所有子类都应重写此方法。

提示:
尽管前向传播的“配方”必须在这个函数里定义,但之后应该调用 Module 的实例,而不是直接调用这个函数,因为前者会运行已注册的钩子(hooks),而后者会静默地忽略它们。


class horizon_plugin_pytorch.quantization.observer_v2.MSEObserver(
    stride: int = 1,
    averaging_constant: float = 0.01,
    ch_axis: int = -1,
    dtype: Union[torch.dtype, horizon_plugin_pytorch.dtype.QuantDType] = 'qint8',
    qscheme: torch.qscheme = torch.per_tensor_symmetric,
    quant_min: int = None,
    quant_max: int = None,
    is_sync_quantize: bool = False,
    factory_kwargs: Dict = None
)

MSE 观察器(MSEObserver)

用于计算量化参数的观察器模块,基于原始张量与量化张量之间的均方误差(MSE)。

该观察器通过线性搜索最小化 MSE 的量化尺度。

参数:

  • stride – 搜索步长。值越大,搜索空间越小,计算时间越短,但精度可能下降。默认值为 1,建议不超过 20。

  • averaging_constant – 用于 min/max 的平滑系数。

  • ch_axis – 通道轴。

  • dtype – 量化后的数据类型。

  • qscheme – 使用的量化方案。

  • quant_min – 最小量化值。未指定时根据 dtype 自动推断。

  • quant_max – 最大量化值。未指定时根据 dtype 自动推断。

  • is_sync_quantize – 是否在使用多设备训练时同步统计信息。

  • factory_kwargs – 传递给 min_val 和 max_val 工厂函数的关键字参数。


forward(x_orig)

定义每次调用时执行的计算。

所有子类都应重写此方法。

提示:
尽管前向传播的“配方”必须在这个函数里定义,但之后应该调用 Module 的实例,而不是直接调用这个函数,因为前者会运行已注册的钩子(hooks),而后者会静默地忽略它们。


class horizon_plugin_pytorch.quantization.observer_v2.MinMaxObserver(
    averaging_constant: float = 0.01,
    ch_axis: int = -1,
    dtype: Union[torch.dtype, horizon_plugin_pytorch.dtype.QuantDType] = 'qint8',
    qscheme: torch.qscheme = torch.per_tensor_symmetric,
    quant_min: int = None,
    quant_max: int = None,
    is_sync_quantize: bool = False,
    factory_kwargs: Dict = None
)

MinMax 观察器(MinMaxObserver)

该观察器基于输入张量的最小值和最大值计算量化参数。模块会记录输入张量的滑动平均最小值和最大值,并使用这些统计量计算量化参数。

参数:

  • averaging_constant – 用于 min/max 的平滑系数。

  • ch_axis – 通道轴。

  • dtype – 量化后的数据类型。

  • qscheme – 使用的量化方案。

  • quant_min – 最小量化值。未指定时根据 dtype 自动推断。

  • quant_max – 最大量化值。未指定时根据 dtype 自动推断。

  • is_sync_quantize – 是否在使用多设备训练时同步统计信息。

  • factory_kwargs – 传递给 min_val 和 max_val 工厂函数的关键字参数。


forward(x_orig)

记录 x 的运行最小值和最大值。


class horizon_plugin_pytorch.quantization.observer_v2.MixObserver(
    averaging_constant: float = 0.01,
    ch_axis: int = -1,
    dtype: Union[torch.dtype, horizon_plugin_pytorch.dtype.QuantDType] = 'qint8',
    qscheme: torch.qscheme = torch.per_tensor_symmetric,
    quant_min: int = None,
    quant_max: int = None,
    is_sync_quantize: bool = False,
    factory_kwargs: Dict = None
)

Mix 观察器(MixObserver)

该观察器基于多种校准方法计算量化参数,并选择量化误差最小的参数。

参数:

  • averaging_constant – 用于 min/max 的平滑系数。

  • ch_axis – 通道轴。

  • dtype – 量化后的数据类型。

  • qscheme – 使用的量化方案。

  • quant_min – 最小量化值。未指定时根据 dtype 自动推断。

  • quant_max – 最大量化值。未指定时根据 dtype 自动推断。

  • is_sync_quantize – 是否在使用多设备训练时同步统计信息。

  • factory_kwargs – 传递给 min_val 和 max_val 工厂函数的关键字参数。


forward(x_orig)

定义每次调用时执行的计算。

所有子类都应重写此方法。

提示:
尽管前向传播的“配方”必须在这个函数里定义,但之后应该调用 Module 的实例,而不是直接调用这个函数,因为前者会运行已注册的钩子(hooks),而后者会静默地忽略它们。


class horizon_plugin_pytorch.quantization.observer_v2.PercentileObserver(
    percentile: float = 99.99,
    bins: int = 2048,
    averaging_constant: float = 0.01,
    ch_axis: int = -1,
    dtype: Union[torch.dtype, horizon_plugin_pytorch.dtype.QuantDType] = 'qint8',
    qscheme: torch.qscheme = torch.per_tensor_symmetric,
    quant_min: int = None,
    quant_max: int = None,
    is_sync_quantize: bool = False,
    factory_kwargs: Dict = None
)

百分位观察器(PercentileObserver)

基于直方图的百分位观察器。直方图在线计算且不会保存。最小值和最大值通过滑动平均计算,用于量化参数计算。

参数:

  • percentile – 直方图的百分位索引。

  • bins – 直方图的分桶数。

  • averaging_constant – 用于 min/max 的平滑系数。

  • ch_axis – 通道轴。

  • dtype – 量化后的数据类型。

  • qscheme – 使用的量化方案。

  • quant_min – 最小量化值。未指定时根据 dtype 自动推断。

  • quant_max – 最大量化值。未指定时根据 dtype 自动推断。

  • is_sync_quantize – 是否在使用多设备训练时同步统计信息。

  • factory_kwargs – 传递给 min_val 和 max_val 工厂函数的关键字参数。


forward(x_orig)

定义每次调用时执行的计算。

所有子类都应重写此方法。

提示:
尽管前向传播的“配方”必须在这个函数里定义,但之后应该调用 Module 的实例,而不是直接调用这个函数,因为前者会运行已注册的钩子(hooks),而后者会静默地忽略它们。


class horizon_plugin_pytorch.quantization.MovingAverageMinMaxObserver(
    averaging_constant=0.01,
    dtype=torch.qint8,
    qscheme=torch.per_tensor_symmetric,
    quant_min=None,
    quant_max=None,
    is_sync_quantize=False,
    factory_kwargs=None
)

滑动平均 MinMax 观察器(MovingAverageMinMaxObserver)

用于基于滑动平均的 min/max 值计算量化参数的观察器模块。

该观察器基于输入张量最小值和最大值的滑动平均计算量化参数。模块记录输入张量的平均最小值和最大值,并使用这些统计量计算量化参数。

参数:

  • averaging_constant – 用于 min/max 的平滑系数。

  • dtype – 量化后的数据类型。

  • qscheme – 使用的量化方案,仅支持 per_tensor_symmetric。

  • reduce_range – 将量化数据类型的范围减少 1 位。

  • quant_min – 最小量化值。

  • quant_max – 最大量化值。

  • is_sync_quantize – 是否使用同步量化。

  • factory_kwargs – 用于注册数据缓冲区的参数。


forward(x_orig)

记录 x 的运行最小值和最大值。


class horizon_plugin_pytorch.quantization.MovingAveragePerChannelMinMaxObserver(
    averaging_constant=0.01,
    ch_axis=0,
    dtype=torch.qint8,
    qscheme=torch.per_channel_symmetric,
    quant_min=None,
    quant_max=None,
    is_sync_quantize=False,
    factory_kwargs=None
)

滑动平均逐通道 MinMax 观察器(MovingAveragePerChannelMinMaxObserver)

用于基于逐通道的滑动平均 min/max 值计算量化参数的观察器模块。

该观察器使用张量的 min/max 统计量计算逐通道量化参数。模块记录输入张量的运行最小值和最大值,并使用这些统计量计算量化参数。

参数:

  • averaging_constant – 用于 min/max 的平滑系数。

  • ch_axis – 通道轴。

  • dtype – 量化后的数据类型。

  • qscheme – 使用的量化方案,仅支持 per_channel_symmetric。

  • quant_min – 最小量化值。

  • quant_max – 最大量化值。

  • is_sync_quantize – 是否使用同步量化。

  • factory_kwargs – 用于注册数据缓冲区的参数。


forward(x_orig)

定义每次调用时执行的计算。

所有子类都应重写此方法。

提示:
尽管前向传播的“配方”必须在这个函数里定义,但之后应该调用 Module 的实例,而不是直接调用这个函数,因为前者会运行已注册的钩子(hooks),而后者会静默地忽略它们。

6.4.3.4. 量化感知训练指南

量化训练通过在模型中插入一些伪量化节点,从而使得通过量化训练得到的模型转换成定点模型时尽可能减少精度损失。 量化训练和传统的模型训练无异,开发者可以从零开始,搭建一个伪量化模型,然后对该伪量化模型进行训练。 由于部署的硬件平台有诸多限制,对于开发者来说,搞清这些限制,并且根据这些限制搭建伪量化模型门槛较高。量化训练工具通过在开发者提供的浮点模型上根据部署平台的限制自动插入伪量化量化算子的方法,降低开发者开发量化模型的门槛。

量化训练由于施加了各种限制,因此,一般来说,量化训练比纯浮点模型的训练更加困难。量化训练工具的目标是降低量化训练的难度,降低量化模型部署的工程难度。

流程和示例

虽然量化训练工具不强制要求用户从一个预训练的浮点模型开始,但是,经验表明,通常从预训练的高精度浮点模型开始量化训练能大大降低量化训练的难度。

from horizon_plugin_pytorch.quantization import get_default_qconfig
# 将模型转为 QAT 状态
default_qat_8bit_fake_quant_qconfig = get_default_qconfig(
    activation_fake_quant="fake_quant",
    weight_fake_quant="fake_quant",
    activation_observer="min_max",
    weight_observer="min_max",
    activation_qkwargs=None,
    weight_qkwargs={
        "qscheme": torch.per_channel_symmetric,
        "ch_axis": 0,
    },
)
default_qat_out_8bit_fake_quant_qconfig = get_default_qconfig(
    activation_fake_quant=None,
    weight_fake_quant="fake_quant",
    activation_observer=None,
    weight_observer="min_max",
    activation_qkwargs=None,
    weight_qkwargs={
        "qscheme": torch.per_channel_symmetric,
        "ch_axis": 0,
    },
)
qat_model = prepare_qat_fx(
    float_model,
    {
        "": default_qat_8bit_fake_quant_qconfig,
        "module_name": {
            "classifier": default_qat_out_8bit_fake_quant_qconfig,
        },
    },
).to(device)
# 加载 Calibration 模型中的量化参数
qat_model.load_state_dict(calib_model.state_dict())
# 进行量化感知训练
# 作为一个 filetune 过程,量化感知训练一般需要设定较小的学习率
optimizer = torch.optim.SGD(
    qat_model.parameters(), lr=0.0001, weight_decay=2e-4
)

for nepoch in range(epoch_num):
    # 注意此处对 QAT 模型 training 状态的控制方法
    qat_model.train()
    set_fake_quantize(qat_model, FakeQuantState.QAT)

    train_one_epoch(
        qat_model,
        nn.CrossEntropyLoss(),
        optimizer,
        None,
        train_data_loader,
        device,
    )

    # 注意此处对 QAT 模型 eval 状态的控制方法
    qat_model.eval()
    set_fake_quantize(qat_model, FakeQuantState.VALIDATION)

    # 测试 qat 模型精度
    top1, top5 = evaluate(
        qat_model,
        eval_data_loader,
        device,
    )
    print(
        "QAT model: evaluation Acc@1 {:.3f} Acc@5 {:.3f}".format(
            top1.avg, top5.avg
        )
    )

# 测试 quantized 模型精度
quantized_model = convert_fx(qat_model.eval()).to(device)

top1, top5 = evaluate(
    quantized_model,
    eval_data_loader,
    device,
)
print(
    "Quantized model: evaluation Acc@1 {:.3f} Acc@5 {:.3f}".format(
        top1.avg, top5.avg
    )
)

注意:

由于部署平台的底层限制,QAT 模型无法完全代表最终上板精度,请务必监控 quantized 模型精度,确保 quantized 模型精度正常,否则可能出现模型上板掉点问题。

由上述示例代码可以看到,与传统的纯浮点模型训练相比,量化训练多了两个步骤:

  1. prepare_qat_fx

  2. 加载 Calibration 模型参数

prepare_qat_fx

这一步骤的目标是对浮点网络进行变换,插入伪量化节点。

加载 Calibration 模型参数

通过加载 Calibration 得到的伪量化参数,来获得一个较好的初始化。

训练迭代

至此,完成了伪量化模型的搭建和参数的初始化,然后就可以进行常规的训练迭代和模型参数更新,并且监控 quantized 模型精度。

伪量化算子

量化训练和传统的浮点模型的训练主要区别在于插入了伪量化算子,并且,不同量化训练算法也是通过伪量化算子来体现的,因此,这里介绍一下伪量化算子。

备注:

由于 BPU 只支持对称量化,因此,这里以对称量化为例介绍。

伪量化过程

以 int8 量化训练为例,一般来说,伪量化算子的计算过程如下:

fake_quant_x = clip(round(x / scale),-128, 127) * scale

和 Conv2d 通过训练来优化 weight, bias 参数类似,伪量化算子要通过训练来优化 scale 参数。 然而,由于 round 作为阶梯函数,其梯度为 0,从而导致了伪量化算子无法直接通过梯度反向传播的方式进行训练。解决这一问题,通常有两种方案:基于统计的方法和基于“学习”的方法。

基于统计的方法

量化地目标是把 Tensor 中的浮点数通过 scale 参数均匀地映射到 int8 表示的 [-128, 127] 的范围上。既然是均匀映射,那么很容易得到 scale 的计算方法:

def compute_scale(x: Tensor):
    xmin, xmax = x.max(), maxv = x.min()
    return max(xmin.abs(), xmax.abs()) / 256.0

由于 Tensor 中数据分布不均匀以及外点问题,又衍生了不同的计算 xmin 和 xmax 的方法。可以参考 MovingAverageMinMaxObserver 等。

在工具中的使用方法请参考 default_qat_8bit_fake_quant_qconfig 及其相关接口。

基于学习的方法

虽然 round 的梯度为 0,研究者通过实验发现,在该场景下,如果直接设置其梯度为 1 也可以使得模型收敛到预期的精度。

def round_ste(x: Tensor):
    return (x.round() - x).detach() + x

在工具中的使用方法请参考 default_qat_8bit_lsq_quant_qconfig 及其相关接口。

有兴趣进一步了解的用户可以参考如下论文:Learned Step Size Quantization

6.4.3.5. 异构模型指南

异构模型介绍

异构模型是部署时一部分运行在 BPU 上,一部分运行在 CPU 上的模型,而非异构模型部署时则完全运行在 BPU 上。通常情况下,以下两类模型在部署时会成为异构模型:

  1. 包含 BPU 不支持算子的模型。

  2. 由于量化精度误差过大,用户指定某些算子运行在 CPU 上的模型。

使用流程

hybrid_qat_workflow

通过 prepare 将浮点模型转为 QAT 模型,训练之后导出为 onnx 格式模型,由 hb_mapper 工具转为 bin 模型。

备注:

用户可以通过 convert 过程得到异构定点模型,用于模型精度评测。

算子限制

由于异构模型对接的是 horizon_nn,因此,其算子的支持情况和 horizon_nn 相同。

主要接口参数说明

horizon_plugin_pytorch.quantization.prepare_qat_fx

  1. 设置 hybrid=True 来开启异构模型功能。

  2. 用户可以通过设置 hybrid_dict 参数来强制指定某些 BPU 支持的算子跑在 CPU 上。

def prepare_qat_fx(
    model: Union[torch.nn.Module, GraphModule],
    qconfig_dict: Dict[str, Any] = None,
    prepare_custom_config_dict: Dict[str, Any] = None,
    optimize_graph: bool = False,
    hybrid: bool = False,
    hybrid_dict: Dict[str, List] = None,
) -> ObservedGraphModule:
    """Prepare QAT 模型
        `model`: torch.nn.Module 或 GraphModule(使用 fuse_fx 后的模型)
        `qconfig_dict`: 定义 Qconfig。如果除了 qconfig_dict 以外,还使用了 eager mode 在 module 内定义 qconfig 的方式,则 module 内定义的 qconfig 优先生效。qconfig_dict 的配置格式如下:
            qconfig_dict = {
                # 可选,全局配置
                "": qconfig,
                # 可选,按 module 类型配置
                "module_type": [(torch.nn.Conv2d, qconfig), ...],
                # 可选,按 module 名配置
                "module_name": [("foo.bar", qconfig),...],
                # 优先级:global < module_type < module_name < module.qconfig
                # 非 module 类型的算子的 qconfig 默认与其父 module 的 qconfig 保持一致,如果需要单独设置,请将这部分单独封装成 module。
            }
        `prepare_custom_config_dict`: 自定义配置字典
            prepare_custom_config_dict = {
                # 暂时只支持 preserved_attributes。一般而言会自动保留所有属性,这个选项只是以防万一,几乎不会用到。
                "preserved_attributes": ["preserved_attr"],
            }
        `optimize_graph`: 保持 cat 输入输出 scale 一致,目前只有在 Bernoulli 架构下有效。
        `hybrid`: 是否使用异构模式。在以下情况下必须打开异构模式:
            1. 模型包含 BPU 不支持的算子或用户希望指定部分 BPU 算子退回 CPU。
            2. 用户希望 QAT 模型与 horizon_nn 对接进行定点化。
        `hybrid_dict`: 定义用户主动指定的 CPU 算子。
            hybrid_dict = {
                # 可选,按 module 类型配置
                "module_type": [torch.nn.Conv2d, ...],
                # 可选,按 module 名配置
                "module_name": ["foo.bar", ...],
                # 优先级:module_type < module_name
                # 与 qconfig_dict 类似,如果想要非 module 类型的算子运行在 CPU 上,需要将这部分单独封装成 module。
            }
    """

horizon_plugin_pytorch.utils.onnx_helper.export_to_onnx

导出 onnx 模型,从而对接 hb_mapper

备注:

该接口也支持非异构模型,其导出的 ONNX 格式模型仅用于可视化。

def export_to_onnx(
    model,
    args,
    f,
    export_params=True,
    verbose=False,
    training=TrainingMode.EVAL,
    input_names=None,
    output_names=None,
    operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH,
    opset_version=11,
    do_constant_folding=True,
    example_outputs=None,
    strip_doc_string=True,
    dynamic_axes=None,
    keep_initializers_as_inputs=None,
    custom_opsets=None,
    enable_onnx_checker=False,
):
    """此接口与 torch.onnx.export 基本一致,隐藏了无需修改的参数,需要的注意参数有:
        `model`: 需要 export 的模型
        `args`: 模型输入,用于 trace 模型
        `f`: 保存的 onnx 文件名或文件描述符
        `operator_export_type`: 算子导出类型
            1. 对于非异构模型,onnx 仅用于可视化,不需要保证实际可用,使用默认值 OperatorExportTypes.ONNX_FALLTHROUGH
            2. 对于异构模型,onnx 需要保证实际可用,使用 None 确保导出的为标准 onnx 算子。
        `opset_version`: 只能为 11,horizon_plugin_pytorch 在 opset 11 中注册了特定的映射规则。
        注意:如果使用公版 torch.onnx.export,需要确保上述参数设置正确,
        并且 import horizon_plugin_pytorch.utils._register_onnx_ops
        以向 opset 11 中注册特定的映射规则。
    """

horizon_plugin_pytorch.quantization.convert_fx

异构模式可以复用 convert_fx 把伪量化模型转换成异构量化模型,用于评测模型精度。

注意:

通过 convert_fx 得到的异构量化模型无法进行部署。目前仅用于评测模型精度。

def convert_fx(
    graph_module: GraphModule,
    convert_custom_config_dict: Dict[str, Any] = None,
    _remove_qconfig: bool = True,
) -> QuantizedGraphModule:
    """转换 QAT 模型,仅用于评测定点模型。
        `graph_module`: 经过 prepare->(calibration)->train 之后的模型
        `convert_custom_config_dict`: 自定义配置字典
            convert_custom_config_dict = {
                # 暂时只支持 preserved_attributes。一般而言会自动保留所有属性,这个选项只是以防万一,几乎不会用到。
                "preserved_attributes": ["preserved_attr"],
            }
        `_remove_qconfig`: convert 之后是否删除 qconfig,一般不会用到
    """

流程和示例

  1. 改造浮点模型。

    • 插入 QuantStubDeQuantStub ,保持与非异构的用法一致。

      • 如果第一个 op 是 cpu op ,那么不需要插入 QuantStub

      • 如果最后一个 op 是 cpu op ,那么可以不用插入 DeQuantStub

    • 对于非 module 的运算,如果需要单独设置 qconfig 或指定其运行在 CPU 上,需要将其封装成 module ,参考示例中的 _SeluModule

  2. 设置 marchX3 设置bernoulli2, X5 设置为bayes-e。

  3. 设置 qconfig 。保留非异构模式下在 module 内设置 qconfig 的配置方式,除此以外,还可以通过 prepare_qat_fx 接口的 qconfig_dict 参数传入 qconfig,具体用法见接口参数说明。

    • 对于 BPU op ,必须保证有 qconfig ,如果其输入 op 不为 QuantStub ,那么还需要保证该输入 op 有 activation qconfig

    • 对于 CPU opqconfig 不会对其产生任何影响,但如果后面接 BPU op ,则必须有 qconfig

    • 推荐设置方式:先设置全局 qconfighorizon.quantization.default_qat_8bit_fake_quant_qconfig (或者 horizon.quantization.default_calib_8bit_fake_quant_qconfig ,根据 calibration 或 qat 阶段选择) ,在此基础上根据需求修改,一般而言,只需要对 int16 和高精度输出的 op 单独设置 qconfig

注意:

目前只有BPU架构为 BAYES_EX5 支持设置 int16 量化。

  1. 设置 hybrid_dict 。可选,具体用法见接口参数说明,如果没有主动指定的 CPU 算子,可以不设置 hybrid_dict

  2. 调用 prepare_qat_fx 并进行 calibration 。参考 horizon_plugin_pytorch 开发指南章节中的 Calibration 小节内容。

  3. 调用 prepare_qat_fx ,加载 calibration 模型并进行 QAT 训练。参考 horizon_plugin_pytorch 开发指南章节中的 量化训练 小节内容。

  4. 调用 convert_fx 。可选,没有评测定点模型精度的需求时可以跳过。

  5. 调用 export_to_onnx 。也可以使用 torch.onnx.export 但需要遵守 export_to_onnx 接口说明中的注意事项。

  6. 使用 hb_mapper 转换 onnx 模型。转换后需检查算子是否运行在预期的设备上,在部分情况下, hb_mapper 仍然需要设置 run_on_cpu 参数。比如:虽然 conv 在 QAT 阶段没有量化,但由于其输入(上一个算子输出)经过了伪量化, hb_mapper 仍然会默认将其量化。

hybrid_qat_run_on_cpu

import copy
import numpy as np
import torch
from horizon_plugin_pytorch.march import March, set_march
from horizon_plugin_pytorch.nn import qat
from horizon_plugin_pytorch.quantization import (
    prepare_qat_fx,
    convert_fx,
    set_fake_quantize,
    FakeQuantState,
    load_observer_params,
)
from horizon_plugin_pytorch.quantization.qconfig import (
    default_calib_8bit_fake_quant_qconfig,
    default_calib_out_8bit_fake_quant_qconfig,
    default_qat_8bit_fake_quant_qconfig,
    default_qat_out_8bit_fake_quant_qconfig,
)
from torch import nn
from torch.quantization import DeQuantStub, QuantStub
from horizon_plugin_pytorch.utils.onnx_helper import export_to_onnx

class _ConvBlock(nn.Module):
    def __init__(self, channels=3):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, 1)
        self.prelu = torch.nn.PReLU()

    def forward(self, input):
        x = self.conv(input)
        x = self.prelu(x)
        return torch.nn.functional.selu(x)

# 封装 functional selu 为 module,便于单独设置
class _SeluModule(nn.Module):
    def forward(self, input):
        return torch.nn.functional.selu(input)

class HybridModel(nn.Module):
    def __init__(self, channels=3):
        super().__init__()
        # 插入 QuantStub
        self.quant = QuantStub()
        self.conv0 = nn.Conv2d(channels, channels, 1)
        self.prelu = torch.nn.PReLU()
        self.conv1 = _ConvBlock(channels)
        self.conv2 = nn.Conv2d(channels, channels, 1)
        self.conv3 = nn.Conv2d(channels, channels, 1)
        self.conv4 = nn.Conv2d(channels, channels, 1)
        self.selu = _SeluModule()
        # 插入 DequantStub
        self.dequant = DeQuantStub()
        self.identity = torch.nn.Identity()

    def forward(self, input):
        x = self.quant(input)
        x = self.conv0(x)
        x = self.identity(x)
        x = self.prelu(x)
        x = torch.nn.functional.selu(x)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.identity(x)
        x = self.conv4(x)
        x = self.selu(x)
        return self.dequant(x)

# 设置 march **X3** 设置BERNOULLI2, **X5** 设置为BAYES_E。
set_march(March.XXX)
data_shape = [1, 3, 224, 224]
data = torch.rand(size=data_shape)
model = HybridModel()
qat_model = copy.deepcopy(model)
# float 模型的推理不要放在 prepare_qat_fx 之后,prepare_qat_fx 会对 float 模型做 inplace 修改
float_res = model(data)

calibration_model = prepare_qat_fx(
    model,
    {
        "": default_calib_8bit_fake_quant_qconfig,
        # selu 为 cpu 算子,conv4 实际上是 bpu 模型的输出,设置为高精度输出
        "module_name": [("conv4", default_calib_out_8bit_fake_quant_qconfig)]
    },
    hybrid=True,
    hybrid_dict={
        "module_name": ["conv1.conv", "conv3"],
        "module_type": [_SeluModule],
    },
)
# calibration 阶段需确保原有模型不会发生变化
calibration_model.eval()
set_fake_quantize(calibration_model, FakeQuantState.CALIBRATION)

for i in range(5):
    calibration_model(torch.rand(size=data_shape))

qat_model = prepare_qat_fx(
    qat_model,
    {
        "": default_qat_8bit_fake_quant_qconfig,
        # selu 为 cpu 算子,conv4 实际上是 bpu 模型的输出,设置为高精度输出
        "module_name": [("conv4", default_qat_out_8bit_fake_quant_qconfig)]
    },
    hybrid=True,
    hybrid_dict={
        "module_name": ["conv1.conv", "conv3"],
        "module_type": [_SeluModule],
    },
)

load_observer_params(calibration_model, qat_model)
set_fake_quantize(calibration_model, FakeQuantState.QAT)

# qat training start
# ......
# qat training end

# 导出 qat.onnx
export_to_onnx(
    qat_model,
    data,
    "qat.onnx",
    operator_export_type=None,
)

# 评测定点模型
quantize_model = convert_fx(qat_model)
quantize_res = quantize_model(data)

打印 QAT 模型的结果。

HybridModel(
  (quant): QuantStub(
    (activation_post_process): FakeQuantize(
      fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8),            quant_min=-128, quant_max=127, dtype=qint8, qscheme=torch.per_tensor_symmetric, ch_axis=-1,         scale=tensor([0.0078]), zero_point=tensor([0])
      (activation_post_process): MovingAverageMinMaxObserver(min_val=tensor([-0.9995]), max_val=tensor([0.9995]))
    )
  )
  (conv0): Conv2d(
    3, 3, kernel_size=(1, 1), stride=(1, 1)
    (weight_fake_quant): FakeQuantize(
      fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8),            quant_min=-128, quant_max=127, dtype=qint8, qscheme=torch.per_channel_symmetric, ch_axis=0,         scale=tensor([0.0038, 0.0041, 0.0016]), zero_point=tensor([0, 0, 0])
      (activation_post_process): MovingAveragePerChannelMinMaxObserver(min_val=tensor([-0.4881, -0.4944,  0.0787]), max_val=tensor([-0.1213,  0.5284,  0.1981]))
    )
    (activation_post_process): FakeQuantize(
      fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8),            quant_min=-128, quant_max=127, dtype=qint8, qscheme=torch.per_tensor_symmetric, ch_axis=-1,         scale=tensor([0.0064]), zero_point=tensor([0])
      (activation_post_process): MovingAverageMinMaxObserver(min_val=tensor([-0.8159]), max_val=tensor([0.8159]))
    )
  )
  (prelu): PReLU(num_parameters=1)
  (conv1): _ConvBlock(
    (conv): Conv2d(3, 3, kernel_size=(1, 1), stride=(1, 1))
    (prelu): PReLU(num_parameters=1)
  )
  (conv2): Conv2d(
    3, 3, kernel_size=(1, 1), stride=(1, 1)
    (weight_fake_quant): FakeQuantize(
      fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8),            quant_min=-128, quant_max=127, dtype=qint8, qscheme=torch.per_channel_symmetric, ch_axis=0,         scale=tensor([0.0040, 0.0044, 0.0040]), zero_point=tensor([0, 0, 0])
      (activation_post_process): MovingAveragePerChannelMinMaxObserver(min_val=tensor([-0.5044, -0.4553, -0.5157]), max_val=tensor([0.1172, 0.5595, 0.4104]))
    )
    (activation_post_process): FakeQuantize(
      fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8),            quant_min=-128, quant_max=127, dtype=qint8, qscheme=torch.per_tensor_symmetric, ch_axis=-1,         scale=tensor([0.0059]), zero_point=tensor([0])
      (activation_post_process): MovingAverageMinMaxObserver(min_val=tensor([-0.7511]), max_val=tensor([0.7511]))
    )
  )
  (conv3): Conv2d(3, 3, kernel_size=(1, 1), stride=(1, 1))
  (conv4): Conv2d(
    3, 3, kernel_size=(1, 1), stride=(1, 1)
    (weight_fake_quant): FakeQuantize(
      fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8),            quant_min=-128, quant_max=127, dtype=qint8, qscheme=torch.per_channel_symmetric, ch_axis=0,         scale=tensor([0.0025, 0.0037, 0.0029]), zero_point=tensor([0, 0, 0])
      (activation_post_process): MovingAveragePerChannelMinMaxObserver(min_val=tensor([-0.2484, -0.4718, -0.3689]), max_val=tensor([ 0.3239, -0.0056,  0.3312]))
    )
    (activation_post_process): None
  )
  (selu): _SeluModule()
  (dequant): DeQuantStub()
  (identity): Identity()
  (prelu_input_dequant): DeQuantStub()
  (selu_1_activation_post_process): _WrappedCalibFakeQuantize(
    (activation_post_process): FakeQuantize(
      fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8),            quant_min=-128, quant_max=127, dtype=qint8, qscheme=torch.per_tensor_symmetric, ch_axis=-1,         scale=tensor([0.0042]), zero_point=tensor([0])
      (activation_post_process): MovingAverageMinMaxObserver(min_val=tensor([-0.5301]), max_val=tensor([0.5301]))
    )
  )
  (conv3_activation_post_process): _WrappedCalibFakeQuantize(
    (activation_post_process): FakeQuantize(
      fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8),            quant_min=-128, quant_max=127, dtype=qint8, qscheme=torch.per_tensor_symmetric, ch_axis=-1,         scale=tensor([0.0072]), zero_point=tensor([0])
      (activation_post_process): MovingAverageMinMaxObserver(min_val=tensor([-0.9156]), max_val=tensor([0.9156]))
    )
  )
  (conv3_input_dequant): DeQuantStub()
  (selu_2_input_dequant): DeQuantStub()
)

def forward(self, input):
    input_1 = input
    quant = self.quant(input_1);  input_1 = None
    conv0 = self.conv0(quant);  quant = None
    identity = self.identity(conv0);  conv0 = None
    prelu_input_dequant_0 = self.prelu_input_dequant(identity);  identity = None
    prelu = self.prelu(prelu_input_dequant_0);  prelu_input_dequant_0 = None
    selu = torch.nn.functional.selu(prelu, inplace = False);  prelu = None
    conv1_conv = self.conv1.conv(selu);  selu = None
    conv1_prelu = self.conv1.prelu(conv1_conv);  conv1_conv = None
    selu_1 = torch.nn.functional.selu(conv1_prelu, inplace = False);  conv1_prelu = None
    selu_1_activation_post_process = self.selu_1_activation_post_process(selu_1);  selu_1 = None
    conv2 = self.conv2(selu_1_activation_post_process);  selu_1_activation_post_process = None
    conv3_input_dequant_0 = self.conv3_input_dequant(conv2);  conv2 = None
    conv3 = self.conv3(conv3_input_dequant_0);  conv3_input_dequant_0 = None
    conv3_activation_post_process = self.conv3_activation_post_process(conv3);  conv3 = None
    identity_1 = self.identity(conv3_activation_post_process);  conv3_activation_post_process = None
    conv4 = self.conv4(identity_1);  identity_1 = None
    selu_2_input_dequant_0 = self.selu_2_input_dequant(conv4);  conv4 = None
    selu_2 = torch.nn.functional.selu(selu_2_input_dequant_0, inplace = False);  selu_2_input_dequant_0 = None
    dequant = self.dequant(selu_2);  selu_2 = None
    return dequant

导出的 onnx 如图所示,红色圈出部分为 CPU 算子。

hybrid_qat_onnx

6.4.3.6. 精度调优工具使用指南

由于浮点转定点过程中存在误差,当您在使用量化训练工具时,难免会碰到量化模型精度掉点问题。通常来说,造成掉点的原因有两大类:

  1. 原有浮点模型不利于量化,如存在共享 op 或共享结构;

  2. QAT 网络结构或配置异常,如模型中存在没有 fuse 的 pattern,没有设置高精度输出等;

  3. 某些算子对量化比较敏感,该算子的量化误差在前向传播过程中逐层累积,最终导致模型输出误差较大。

针对上述情况,量化训练工具提供了精度调优工具来帮助您快速定位并解决精度问题,主要包括如下模块:

  • 模型结构检查:检查模型中是否存在共享 op、没有 fuse 的 pattern 或者不符合预期的量化配置;

  • QuantAnalysis:自动比对分析两个模型,定位到量化模型中异常算子或者量化敏感 op;

  • ModelProfiler:获得模型中每一个 op 的数值特征信息,如输入输出的最大最小值等。

快速上手

当碰到量化模型精度掉点问题时,我们推荐按照如下的流程使用精度调优工具。

  1. 检查模型中是否存在不利于量化的结构或者异常配置;

  2. 使用 QuantAnalysis 模块进行分析,具体步骤如下:

    1. 找到一个 bad case 作为模型的输入。bad case 是指基准模型和待分析模型输出相差最大的那个输入;

    2. 进行量化敏感度分析,目前的经验是 L1 敏感度排序前 n 个通常为量化敏感 op(不同的模型 n 的数值不一样,暂无自动确定的方法,需要手动尝试,如前 10 个,20 个…)。将量化敏感 op 设置高精度量化(如 int16 量化),重新进行量化流程;

    3. 或者逐层比较两个模型的输入输出等信息,检查是否存在数据范围过大或者 scale 不合理等量化异常的 op,如某些具有物理含义的 op 应设置固定 scale。

整体的流程图如下:

new_debug_flow

一个完整的例子如下。


    from copy import deepcopy

    import torch
    from torch import nn
    from torch.quantization import DeQuantStub, QuantStub

    from horizon_plugin_pytorch.march import March, set_march
    from horizon_plugin_pytorch.quantization.qconfig import (
        default_qat_8bit_fake_quant_qconfig,
    )
    from horizon_plugin_pytorch.quantization.quantize_fx import prepare_qat_fx
    from horizon_plugin_pytorch.quantization import hbdk4 as hb4
    from horizon_plugin_pytorch.utils.check_model import check_qat_model
    from horizon_plugin_profiler import QuantAnalysis, ModelProfiler


    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.conv = nn.Conv2d(3, 3, 1)
            self.relu = nn.ReLU()
            self.quant = QuantStub()
            self.dequant = DeQuantStub()
        
        def forward(self, x):
            x = self.quant(x)
            x = self.conv(x)
            x = self.relu(x)
            x = torch.nn.functional.interpolate(
                x, scale_factor=1.3, mode="bilinear", align_corners=False
            )
            x = self.dequant(x)
            return x

    data = torch.rand((1, 3, 32, 32))
    float_net = Net()
    float_net(data)

    set_march(March.XXX)
    float_net.qconfig = default_qat_8bit_fake_quant_qconfig

    qat_net = deepcopy(float_net)
    qat_net = prepare_qat_fx(qat_net)

    ############################### 模型结构检查 ##############################
    # 确认提示的异常层是否符合预期
    check_qat_model(qat_net, data, save_results=True)
    ##########################################################################

    qat_net(data)

    quantized_net = deepcopy(qat_net)
    quantized_net = convert_fx(quantized_net)

    ############################### quant analysis ############################

    # 1. 初始化
    qa = QuantAnalysis(
        baseline_model=float_net,
        analysis_model=qat_net,
        analysis_model_type="fake_quant",
        out_dir="./floatvsqat",
    )

    # 也支持对比 qat 和 quantized
    # qa = QuantAnalysis(
    #     baseline_model=qat_net,
    #     analysis_model=quantized_net,
    #     analysis_model_type="quantized",
    #     out_dir="./qatvsquantized",
    # )

    # 2. 设置 badcase 输入。
    qa.set_bad_case(data)

    # 实际场景下推荐使用 auto_find_bad_case 在整个 dataloader 上搜索 bad case
    # 也支持设置 num_steps 参数来控制搜索的范围
    # qa.auto_find_bad_case(your_dataloader, num_steps=100)

    # 3. 运行两个模型
    qa.run()

    # 4. 两个模型逐层比较。确认 abnormal_layer_advisor.txt 提示的异常层是否符合预期
    # qa.compare_per_layer()

    # 5. 计算敏感度节点。可以将 topk 排序的敏感度节点设置高精度来尝试提升量化模型精度
    qa.sensitivity()

    ##########################################################################

API Reference

模型结构检查

    # from horizon_plugin_pytorch.utils.check_model import check_qat_model

    def check_qat_model(
        model: torch.nn.Module,
        example_inputs: Any,
        save_results: bool = False,
        out_dir: Optional[str] = None,
    ):

检查 calibration/qat 模型中是否存在不利于量化的结构以及量化 qconfig 配置是否符合预期。

参数

  • model: 待检查模型

  • example_inputs: 模型输入

  • save_results: 是否将检查结果保存到 txt 文件。默认 False。

  • out_dir: 结果文件 ‘model_check_result.txt’ 的保存路径。默认空,保存到当前路径下。

输出

  • 屏幕输出:检查出的异常层

  • model_check_result.txt:在 save_results = True 时生成。主要由5部分组成

    1. 未 fuse 的 pattern

    2. 每个 module 的调用次数。正常每个 op 仅调用 1 次,0 表示未被调用,超过 1 次则表示被共享了多次;

    3. 每个 op 输出的 qconfig 配置;

    4. 每个 op weight(如果有的话)的 qconfig 配置;

    5. 异常 qconfig 提示(如果有的话)。

Fusable modules are listed below:
name    type
------  -----------------------------------------------------
conv    <class 'horizon_plugin_pytorch.nn.qat.conv2d.Conv2d'>
relu    <class 'horizon_plugin_pytorch.nn.qat.relu.ReLU'>



Each module called times:
name       called times
-------  --------------
conv                  1
relu                  1
quant                 1
dequant               1

Each layer out qconfig:
+---------------+-----------------------------------------------------------+---------------+---------------+----------------+-----------------------------+
| Module Name   | Module Type                                               | Input dtype   | out dtype     | ch_axis        | observer                    |
|---------------+-----------------------------------------------------------+---------------+---------------+----------------+-----------------------------|
| quant         | <class 'horizon_plugin_pytorch.nn.qat.stubs.QuantStub'>   | torch.float32 | qint8         | -1             | MovingAverageMinMaxObserver |
| conv          | <class 'horizon_plugin_pytorch.nn.qat.conv2d.Conv2d'>     | qint8         | qint8         | -1             | MovingAverageMinMaxObserver |
| relu          | <class 'horizon_plugin_pytorch.nn.qat.relu.ReLU'>         | qint8         | qint8         | qconfig = None |                             |
| dequant       | <class 'horizon_plugin_pytorch.nn.qat.stubs.DeQuantStub'> | qint8         | torch.float32 | qconfig = None |                             |
+---------------+-----------------------------------------------------------+---------------+---------------+----------------+-----------------------------+

Weight qconfig:
+---------------+-------------------------------------------------------+----------------+-----------+---------------------------------------+
| Module Name   | Module Type                                           | weight dtype   |   ch_axis | observer                              |
|---------------+-------------------------------------------------------+----------------+-----------+---------------------------------------|
| conv          | <class 'horizon_plugin_pytorch.nn.qat.conv2d.Conv2d'> | qint8          |         0 | MovingAveragePerChannelMinMaxObserver |
+---------------+-------------------------------------------------------+----------------+-----------+---------------------------------------+
`prepare_qat/prepare_qat_fx` 流程中也已集成该接口,您可以设置 `verbose=1` 打开该检查功能。我们推荐您在进行 QAT 训练之前,使用此接口进行检查,并根据检查结果对模型做针对性的调整。

QuantAnalysis 类

QuantAnalysis 类可以自动寻找两个模型输出最大的 bad case,并以此作为输入,逐层比较两个模型的输出。此外,QuantAnalysis 类还提供计算敏感度功能,您可以尝试将敏感度排名 topk 的节点设置高精度,如 int16 量化,来提升量化模型精度。

class QuantAnalysis(object):
    def __init__(
        self,
        baseline_model: torch.nn.Module,
        analysis_model: torch.nn.Module,
        analysis_model_type: str,
        out_dir: Optional[str] = None,
    )

参数

  • baseline_model: 基准模型(高精度)

  • analysis_model:待分析的模型(精度掉点)

  • analysis_model_type: 待分析的模型类型。支持两种输入

    • fake_quant:待分析的模型可以是精度掉点的 calibration/qat 模型,此时基准模型可以是原始浮点模型或者一个精度达标的 int8/int16 混合配置的 calibration/qat 模型

    • quantized:待分析的模型是精度掉点的定点问题,此时基准模型必须是一个精度达标的 calibration/qat 模型

  • out_dir:指定比较结果的输出目录

该类中各个 method 如下。

auto_find_bad_case
    def auto_find_bad_case(
        self,
        data_generator: Iterable,
        num_steps: Optional[int] = None,
        metric: str = "L1",
        device: Optional[Union[torch.device, str, int]] = None,
        custom_metric_func: Optional[Callable] = None,
        custom_metric_order_seq: Optional[str] = None,
    ):

自动寻找导致两个模型输出最差的 badcase。

参数

  • data_generator:dataloader 或者一个自定义的迭代器,每次迭代产生一个数据

  • num_steps:迭代 steps 次数

  • metric:指定何种 metric 作为 badcase 的 metric。默认使用 L1 最差的结果。支持 Cosine/MSE/L1/KL/SQNR/custom。若为 custom,表示使用自定义的 metric 计算方法,此时 custom_metric_func 和 custom_metric_order_seq 两个参数必须不为 None

  • device:指定模型运行 device

  • custom_metric_func:自定义模型输出比较函数

  • custom_metric_order_seq:自定义模型输出比较函数的排序规则,仅支持 “ascending”/”descending”,表示升序/降序

set_bad_case
    def set_bad_case(self, data)

手动设置 badcase。

参数

  • data: badcase输入

load_bad_case
    def load_bad_case(self, filename: Optional[str] = None)

从指定的文件中加载 badcase。

参数

  • filename:指定的文件路径

save_bad_case
    def save_bad_case(self)

将 badcase 保存到 {self.out_dir}/badcase.pt 文件。

set_model_profiler_dir
    def set_model_profiler_dir(
        self,
        baseline_model_profiler_path: str,
        analysis_model_profiler_path: str,
    ):

手动指定 model_profiler 的输出保存路径。

某些情况下,在 QuantAnalysis 初始化之前,ModelProfiler 就已定义并运行,此时可以直接指定已有的 ModelProfiler 路径,跳过 QuantAnalysis 的 run 步骤,直接比较两个模型的输出。

参数

  • baseline_model_profiler_path:基准模型的 profiler 路径

  • analysis_model_profiler_path:待分析模型的 profiler 路径

run
    def run(
        self,
        device: Optional[Union[torch.device, str, int]] = None,
    )

运行两个模型并分别保存模型中每一层的结果。

参数

  • device:模型运行的 device

compare_per_layer
    def compare_per_layer(self)

比较两个模型中每一层的结果。

输出

  • abnormal_layer_advisor.txt: 所有异常层,包括相似度低/数据范围过大/输入没有归一化/输出没有高精度 等情况

  • profiler.html: 可视化展示所有 metric 指标及模型中每一层的数据范围 diff

profiler_html

  • compare_per_layer_out.txt: 以表格的形式展示模型中每层 layer 的具体信息,包括各种指标、数据范围、量化 dtype 等。从左到右每一列分别表示:

    • Index:op index

    • mod_name:该 op 名字,若 op 为 module 类型,则显示该 module 在模型中的 prefix name,若为 function 类型,则不显示

    • base_op_type:基准模型中该 op 的 type,可能是 module 类型或者 function 名称

    • analy_op_type:待分析模型中该 op 的 type,可能是 module 类型或者 function 名称

    • Shape:该 op 输出的 shape

    • quant_dtype:该 op 输出的量化类型

    • Qscale:该 op 输出的量化 scale

    • Cosine:该 op 在两个模型中输出的余弦相似度

    • MSE:该 op 在两个模型中输出的 MSE 距离

    • L1:该 op 在两个模型中输出的 L1 距离

    • KL:该 op 在两个模型中输出的 KL 相似度

    • SQNR:该 op 在两个模型中输出的 SQNR 相似度

    • Atol:该 op 在两个模型中输出的绝对误差

    • Rtol:该 op 在两个模型中输出的相对误差

    • base_model_min:基准模型中该 op 输出的最小值

    • analy_model_min:待分析模型中该 op 输出的最小值

    • base_model_max:基准模型中该 op 输出的最大值

    • analy_model_max:待分析模型中该 op 输出的最大值

    • base_model_mean:基准模型中该 op 输出的平均值

    • analy_model_mean:待分析模型中该 op 输出的平均值

    • base_model_var:基准模型中该 op 输出的方差

    • analy_model_var:待分析模型中该 op 输出的方差

    +----+------------+--------------------------------------------------------------------+--------------------------------------------------------------------+----------------------------+---------------+-----------+-----------+-----------+-----------+-----------+------------+-----------+-------------------------------------------------+------------------+-------------------+------------------+-------------------+-------------------+--------------------+------------------+-------------------+
    |    | mod_name   | base_op_type                                                       | analy_op_type                                                      | shape                      | quant_dtype   |    qscale |    Cosine |       MSE |        L1 |        KL |       SQNR |      Atol |                                            Rtol |   base_model_min |   analy_model_min |   base_model_max |   analy_model_max |   base_model_mean |   analy_model_mean |   base_model_var |   analy_model_var |
    |----+------------+--------------------------------------------------------------------+--------------------------------------------------------------------+----------------------------+---------------+-----------+-----------+-----------+-----------+-----------+------------+-----------+-------------------------------------------------+------------------+-------------------+------------------+-------------------+-------------------+--------------------+------------------+-------------------|
    |  0 | quant      | torch.ao.quantization.stubs.QuantStub                              | horizon_plugin_pytorch.nn.qat.stubs.QuantStub                      | torch.Size([1, 3, 32, 32]) | qint8         | 0.0078354 | 0.9999924 | 0.0000052 | 0.0019757 | 0.0000006 | 48.1179886 | 0.0039178 |                                       1.0000000 |        0.0003164 |         0.0000000 |        0.9990171 |         0.9950994 |         0.5015678 |          0.5014852 |        0.0846284 |         0.0846521 |
    |  1 | conv       | torch.nn.modules.conv.Conv2d                                       | horizon_plugin_pytorch.nn.qat.conv2d.Conv2d                        | torch.Size([1, 3, 32, 32]) | qint8         | 0.0060428 | 0.9999037 | 0.0000085 | 0.0023614 | 0.0000012 | 37.1519432 | 0.0096008 |                                      48.2379990 |       -0.7708085 |        -0.7674332 |        0.4674263 |         0.4652941 |        -0.0411330 |         -0.0412943 |        0.0423415 |         0.0422743 |
    |  2 | relu       | torch.nn.modules.activation.ReLU                                   | horizon_plugin_pytorch.nn.qat.relu.ReLU                            | torch.Size([1, 3, 32, 32]) | qint8         | 0.0060428 | 0.9998640 | 0.0000037 | 0.0010231 | 0.0000004 | 35.5429153 | 0.0093980 |                                      48.2379990 |        0.0000000 |         0.0000000 |        0.4674263 |         0.4652941 |         0.0641222 |          0.0639115 |        0.0090316 |         0.0089839 |
    |  3 |            | horizon_plugin_pytorch.nn.interpolate.autocasted_interpolate_outer | horizon_plugin_pytorch.nn.interpolate.autocasted_interpolate_outer | torch.Size([1, 3, 41, 41]) | qint8         | 0.0060428 | 0.9234583 | 0.0012933 | 0.0245362 | 0.0001882 |  8.1621437 | 0.1928777 | 340282346638528859811704183484516925440.0000000 |        0.0000000 |         0.0000000 |        0.3509629 |         0.3504813 |         0.0643483 |          0.0639483 |        0.0043305 |         0.0043366 |
    |  4 | dequant    | torch.ao.quantization.stubs.DeQuantStub                            | horizon_plugin_pytorch.nn.qat.stubs.DeQuantStub                    | torch.Size([1, 3, 41, 41]) | torch.float32 |           | 0.9234583 | 0.0012933 | 0.0245362 | 0.0001882 |  8.1621437 | 0.1928777 | 340282346638528859811704183484516925440.0000000 |        0.0000000 |         0.0000000 |        0.3509629 |         0.3504813 |         0.0643483 |          0.0639483 |        0.0043305 |         0.0043366 |
    +----+------------+--------------------------------------------------------------------+--------------------------------------------------------------------+----------------------------+---------------+-----------+-----------+-----------+-----------+-----------+------------+-----------+-------------------------------------------------+------------------+-------------------+------------------+-------------------+-------------------+--------------------+------------------+-------------------+
    
  • compare_per_layer_out.csv: 以 csv 的格式展示每层的具体信息。内容和 compare_per_layer_out.txt 完全一致,csv 文件的存储格式方便您通过 excel 等软件打开分析。

sensitivity
    def sensitivity(
        self,
        device: Optional[torch.device] = None,
        metric: str = "L1",
        reserve: bool = False
    ):

模型中各个节点的敏感度排序。适用于 float 转 calibration/qat 的精度掉点问题。

sensitivity 函数不支持计算 hbir 模型的敏感度。

参数

  • device:指定模型运行的 device

  • metric:相似度排序的 metric,默认 L1,支持 Cosine/MSE/L1/KL/SQNR

  • reserve:是否反序打印敏感度节点,以支持将某些 int16 算子退回 int8 来提升上板性能

输出

  • sensitive_ops.txt。文件中按照量化敏感度从高到低的顺序排列 op。从左到右每一列分别表示:

    • op_name:op 名字,

    • sensitive_type:计算量化敏感的类型,包括三种

      • activation:仅量化该 op 输出的量化敏感度

      • weight:仅量化该 op 权重的量化敏感度

      • both:同时量化该 op 输出和权重的量化敏感度

    • op_type:op 类型

    • metric:计算敏感度的指标。按照敏感度从高到低的顺序排序。支持 Cosine/L1/MSE/KL/SQNR 五种指标。默认使用 L1。

      • L1:取值范围 [0, $+\infty$],数值越大则该 op 对量化越敏感(从大到小排序)

      • Cosine:取值范围 [0,1],越接近 0 则该 op 对量化越敏感(从小到大排序)

      • MSE:取值范围 [0, $+\infty$],数值越大则该 op 对量化越敏感(从大到小排序)

      • KL:取值范围 [0, $+\infty$],数值越大则该 op 对量化越敏感(从大到小排序)

      • SQNR:取值范围 [0, $+\infty$],数值越小则该 op 对量化越敏感(从小到大排序)

  • sensitive_ops.pt。使用 torch.save 保存的敏感度排序的列表,方便您后续加载使用。列表格式见返回值部分说明。

返回值

敏感度 List,List 中每个元素都是记录一个 op 敏感度信息的子 list。子 List 中从左到右每一项分别为 [op_name, sensitive_type, op_type, metric1, metric2, ...]

整个 List 示例如下。

[
    [op1, "activation", op1_type, L1],
    [op2, "activation", op2_type, L1],
    [op3, "activation", op3_type, L1],
    [op1, "weight", op1_type, L1],
    [op2, "both", op2_type, L1],
    ...
]

您可以将量化敏感度排名前 n 的 op 配置高精度(如int16)来尝试提升量化模型精度。

op_name    sensitive_type    op_type                                                         L1
---------  ----------------  -------------------------------------------------------  ---------
quant      activation        <class 'horizon_plugin_pytorch.nn.qat.stubs.QuantStub'>  0.0245567
conv       activation        <class 'horizon_plugin_pytorch.nn.qat.conv2d.Conv2d'>    0.0245275
conv       both              <class 'horizon_plugin_pytorch.nn.qat.conv2d.Conv2d'>    0.0245275
conv       weight            <class 'horizon_plugin_pytorch.nn.qat.conv2d.Conv2d'>    0.024501
clean
    def clean(self)

清除中间结果。仅保留比较结果等文件。

ModelProfiler 类

统计模型 forward 过程中,每一层算子的输入输出等信息。

# from horizon_plugin_profiler import ModelProfiler

class ModelProfiler(object):
    def __init__(
        self,
        model: torch.nn.Module,
        out_dir: str,
    )

参数

  • model: 需要统计的模型

  • out_dir: 相关文件保存的路径

该类仅支持通过 with 语句的方式使用。
with ModelProfiler(net, "./profiler_dir") as p:
    net(data)

p.get_info_manager.table()
p.get_info_manager.tensorboard()

该类中其中各个 method 如下。

get_info_manager
    def get_info_manager(self)

获得管理每个 op 信息的结构体。

返回值

管理存储的每个 op 信息的结构体 OpRunningInfoManager。其中两个重要的接口如下。

table
class OpRunningInfoManager:
    def table(
        self,
        out_dir: str = None,
        prefixes: Tuple[str, ...] = None,
        types: Tuple[Type, ...] = None,
        with_stack: bool = False,
    )

在一个表格中展示单个模型统计量。存储到 statistic.txt 文件中

参数

  • out_dir:statistic.txt 文件的存储路径,默认 None,存储到 self.out_dir

  • prefixes:需要统计的模型中 op 的 prefixes 。默认统计所有 op

  • types:需要统计的模型中 op 的 type。默认统计所有 op

  • with_stack: 是否显示每个 op 在代码中对应的位置

输出

statistic.txt 文件,从左到右每一列分别为:

  • Index: op index

  • Op Name:op type,module 类名或者 function 名

  • Mod Name:若是 module 类,则显示该 module 在模型中的 prefix name;若是 function 类型,则显示该 function 所在的 module prefix name。

  • Attr:input/output/weight/bias

  • Dtype:tensor 的数据类型

  • Scale:tensor 的 scale

  • Min:当前 tensor 的最小值

  • Max:当前 tensor 的最大值

  • Mean:当前 tensor 的平均值

  • Var:当前 tensor 中数值的方差

  • Shape:tensor shape

+---------+--------------------------------------------------------------------+------------+--------+---------------+-----------+------------+-----------+------------+-----------+----------------------------+
| Index   | Op Name                                                            | Mod Name   | Attr   | Dtype         | Scale     | Min        | Max       | Mean       | Var       | Shape                      |
|---------+--------------------------------------------------------------------+------------+--------+---------------+-----------+------------+-----------+------------+-----------+----------------------------|
| 0       | horizon_plugin_pytorch.nn.qat.stubs.QuantStub                      | quant      | input  | torch.float32 |           | 0.0003164  | 0.9990171 | 0.5015678  | 0.0846284 | torch.Size([1, 3, 32, 32]) |
| 0       | horizon_plugin_pytorch.nn.qat.stubs.QuantStub                      | quant      | output | qint8         | 0.0078354 | 0.0000000  | 0.9950994 | 0.5014852  | 0.0846521 | torch.Size([1, 3, 32, 32]) |
| 1       | horizon_plugin_pytorch.nn.qat.conv2d.Conv2d                        | conv       | input  | qint8         | 0.0078354 | 0.0000000  | 0.9950994 | 0.5014852  | 0.0846521 | torch.Size([1, 3, 32, 32]) |
| 1       | horizon_plugin_pytorch.nn.qat.conv2d.Conv2d                        | conv       | weight | torch.float32 |           | -0.5315086 | 0.5750652 | 0.0269936  | 0.1615299 | torch.Size([3, 3, 1, 1])   |
| 1       | horizon_plugin_pytorch.nn.qat.conv2d.Conv2d                        | conv       | bias   | torch.float32 |           | -0.4963555 | 0.4448483 | -0.0851902 | 0.2320642 | torch.Size([3])            |
| 1       | horizon_plugin_pytorch.nn.qat.conv2d.Conv2d                        | conv       | output | qint8         | 0.0060428 | -0.7674332 | 0.4652941 | -0.0412943 | 0.0422743 | torch.Size([1, 3, 32, 32]) |
| 2       | horizon_plugin_pytorch.nn.qat.relu.ReLU                            | relu       | input  | qint8         | 0.0060428 | -0.7674332 | 0.4652941 | -0.0412943 | 0.0422743 | torch.Size([1, 3, 32, 32]) |
| 2       | horizon_plugin_pytorch.nn.qat.relu.ReLU                            | relu       | output | qint8         | 0.0060428 | 0.0000000  | 0.4652941 | 0.0639115  | 0.0089839 | torch.Size([1, 3, 32, 32]) |
| 3       | horizon_plugin_pytorch.nn.interpolate.autocasted_interpolate_outer |            | input  | qint8         | 0.0060428 | 0.0000000  | 0.4652941 | 0.0639115  | 0.0089839 | torch.Size([1, 3, 32, 32]) |
| 3       | horizon_plugin_pytorch.nn.interpolate.autocasted_interpolate_outer |            | output | qint8         | 0.0060428 | 0.0000000  | 0.3504813 | 0.0639483  | 0.0043366 | torch.Size([1, 3, 41, 41]) |
| 4       | horizon_plugin_pytorch.nn.qat.stubs.DeQuantStub                    | dequant    | input  | qint8         | 0.0060428 | 0.0000000  | 0.3504813 | 0.0639483  | 0.0043366 | torch.Size([1, 3, 41, 41]) |
| 4       | horizon_plugin_pytorch.nn.qat.stubs.DeQuantStub                    | dequant    | output | torch.float32 |           | 0.0000000  | 0.3504813 | 0.0639483  | 0.0043366 | torch.Size([1, 3, 41, 41]) |
+---------+--------------------------------------------------------------------+------------+--------+---------------+-----------+------------+-----------+------------+-----------+----------------------------+
tensorboard
class OpRunningInfoManager:
    def tensorboard(
        self,
        out_dir: str = None,
        prefixes: Tuple[str, ...] = None,
        types: Tuple[Type, ...] = None,
        force_per_channel: bool = False,
    ):

在 tensorboard 中显示每一层输入输出直方图。

参数

  • out_dir: tensorboard 相关文件保目录。默认保存到 self.out_dir/tensorboard 目录下

  • prefixes:需要统计的模型中 op 的 prefixes。默认统计所有

  • types:需要统计的模型中 op 的 type。默认统计所有

  • force_per_channel:是否以 per_channel 量化的方式展示直方图

输出

tensorboard 文件,打开后截图如下。

tensorboard

6.4.3.7. 量化部署 PT 模型的跨设备 Inference 说明

量化部署的 pt 模型要求 trace 时使用的 device 和后续 infer 时使用的 device 一致。

若用户试图直接通过 to(device) 操作修改 pt 模型的 device,可能会出现模型 forward 报错的问题,torch 官方对此进行了解释,见 TorchScript-Frequently Asked Questions — PyTorch documentation

下面举例说明:

import torch


class Net(torch.nn.Module):
    def forward(self, x: torch.Tensor):
        y = torch.ones(x.shape, device=x.device)
        z = torch.zeros_like(x)

        return y + z


script_mod = torch.jit.trace(
    Net(), torch.rand(2, 3, 3, 3, device=torch.device("cpu"))
)
script_mod.to(torch.device("cuda"))
print(script_mod.graph)

# graph(%self : __torch__.Net,
#       %x : Float(2, 3, 3, 3, strides=[27, 9, 3, 1], requires_grad=0, device=cpu)):
#   %4 : int = prim::Constant[value=0]()
#   %5 : int = aten::size(%x, %4)
#   %6 : Long(device=cpu) = prim::NumToTensor(%5)
#   %16 : int = aten::Int(%6)
#   %7 : int = prim::Constant[value=1]()
#   %8 : int = aten::size(%x, %7)
#   %9 : Long(device=cpu) = prim::NumToTensor(%8)
#   %17 : int = aten::Int(%9)
#   %10 : int = prim::Constant[value=2]()
#   %11 : int = aten::size(%x, %10)
#   %12 : Long(device=cpu) = prim::NumToTensor(%11)
#   %18 : int = aten::Int(%12)
#   %13 : int = prim::Constant[value=3]()
#   %14 : int = aten::size(%x, %13)
#   %15 : Long(device=cpu) = prim::NumToTensor(%14)
#   %19 : int = aten::Int(%15)
#   %20 : int[] = prim::ListConstruct(%16, %17, %18, %19)
#   %21 : NoneType = prim::Constant()
#   %22 : NoneType = prim::Constant()
#   %23 : Device = prim::Constant[value="cpu"]()
#   %24 : bool = prim::Constant[value=0]()
#   %y : Float(2, 3, 3, 3, strides=[27, 9, 3, 1], requires_grad=0, device=cpu) = aten::ones(%20, %21, %22, %23, %24)
#   %26 : int = prim::Constant[value=6]()
#   %27 : int = prim::Constant[value=0]()
#   %28 : Device = prim::Constant[value="cpu"]()
#   %29 : bool = prim::Constant[value=0]()
#   %30 : NoneType = prim::Constant()
#   %z : Float(2, 3, 3, 3, strides=[27, 9, 3, 1], requires_grad=0, device=cpu) = aten::zeros_like(%x, %26, %27, %28, %29, %30)
#   %32 : int = prim::Constant[value=1]()
#   %33 : Float(2, 3, 3, 3, strides=[27, 9, 3, 1], requires_grad=0, device=cpu) = aten::add(%y, %z, %32)
#   return (%33)

可以看到,在调用 to(torch.device("cuda")) 后,模型的 graph 中记录的 aten::onesaten::zeros_like 的 device 参数仍为 prim::Constant[value="cpu"](),因此在模型 forward 时,它们的输出仍为 cpu Tensor。这是因为 to(device) 只能移动模型中的 buffer(weight、bias 等),无法修改 ScriptModule 的 graph。

torch 官方对以上限制给出的解决方案是,在 trace 前就确定好 pt 模型将要在哪个 device 上执行,并在对应的 device 上 trace 即可。

针对以上限制,训练工具建议根据具体场景选择以下解决方案:

PT 模型执行使用的 device 和 trace 不一致

对于可以确定 pt 模型将仅在 GPU 上执行,只需要修改卡号的情况,我们首先推荐使用 cuda:0,即零号卡进行 trace。在使用模型时,用户可以通过 torch.cuda.set_device 接口,将物理上的任意卡映射为逻辑上的“零卡”,此时使用 cuda:0 trace 出的模型实际将在指定的物理卡上运行。

若 trace 时使用的 device 和执行时使用的 device 存在 CPU、GPU 的不一致,用户可以使用 horizon_plugin_pytorch.jit.to_device 接口实现 pt 模型的 device 迁移。此接口会寻找模型 graph 中的 device 参数,并将它们替换为需要的值。效果如下:

from horizon_plugin_pytorch.jit import to_device

script_mod = to_device(script_mod, torch.device("cuda"))
print(script_mod.graph)

# graph(%self : __torch__.Net,
#       %x.1 : Tensor):
#   %38 : bool = prim::Constant[value=0]()
#   %60 : Device = prim::Constant[value="cuda"]()
#   %34 : NoneType = prim::Constant()
#   %3 : int = prim::Constant[value=0]()
#   %10 : int = prim::Constant[value=1]()
#   %17 : int = prim::Constant[value=2]()
#   %24 : int = prim::Constant[value=3]()
#   %41 : int = prim::Constant[value=6]()
#   %4 : int = aten::size(%x.1, %3)
#   %5 : Tensor = prim::NumToTensor(%4)
#   %8 : int = aten::Int(%5)
#   %11 : int = aten::size(%x.1, %10)
#   %12 : Tensor = prim::NumToTensor(%11)
#   %15 : int = aten::Int(%12)
#   %18 : int = aten::size(%x.1, %17)
#   %19 : Tensor = prim::NumToTensor(%18)
#   %22 : int = aten::Int(%19)
#   %25 : int = aten::size(%x.1, %24)
#   %26 : Tensor = prim::NumToTensor(%25)
#   %32 : int = aten::Int(%26)
#   %33 : int[] = prim::ListConstruct(%8, %15, %22, %32)
#   %y.1 : Tensor = aten::ones(%33, %34, %34, %60, %38)
#   %z.1 : Tensor = aten::zeros_like(%x.1, %41, %3, %60, %38, %34)
#   %50 : Tensor = aten::add(%y.1, %z.1, %10)
#   return (%50)

多卡并行推理

在此场景下,用户需要通过 trace 或 to_device 的方式取得 cuda:0 上的 pt 模型,并且为每块卡单独开启一个进程,通过 torch.cuda.set_device 的方式为每个进程设置不同的默认卡。一个简单的示例如下:

import os
import torch
import signal
import torch.distributed as dist
import torch.multiprocessing as mp
from horizon_plugin_pytorch.jit import to_device

model_path = "path_to_pt_model_file"

def main_func(rank, world_size, device_ids):
    torch.cuda.set_device(device_ids[rank])
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

    model = to_device(torch.jit.load(model_path), torch.device("cuda"))

    # 数据加载,模型 forward,精度计算等内容此处省略

def launch(device_ids):
    try:
        world_size = len(device_ids)
        mp.spawn(
            main_func,
            args=(world_size, device_ids),
            nprocs=world_size,
            join=True,
        )
    # 当按下 Ctrl+c 时,关闭所有子进程
    except KeyboardInterrupt:
        os.killpg(os.getpgid(os.getpid()), signal.SIGKILL)

launch([0, 1, 2, 3])

上述操作对 pt 模型的处理和 torch.nn.parallel.DistributedDataParallel 的做法一致,数据加载和模型精度计算相关内容请参考 Getting Started with Distributed Data Parallel — PyTorch Tutorials

6.4.3.8. 量化部署 PT 模型的跨设备 Inference 说明

量化部署的 pt 模型要求 trace 时使用的 device 和后续 infer 时使用的 device 一致。

若用户试图直接通过 to(device) 操作修改 pt 模型的 device,可能会出现模型 forward 报错的问题,torch 官方对此进行了解释,见 TorchScript-Frequently Asked Questions — PyTorch documentation

下面举例说明:

import torch


class Net(torch.nn.Module):
    def forward(self, x: torch.Tensor):
        y = torch.ones(x.shape, device=x.device)
        z = torch.zeros_like(x)

        return y + z


script_mod = torch.jit.trace(
    Net(), torch.rand(2, 3, 3, 3, device=torch.device("cpu"))
)
script_mod.to(torch.device("cuda"))
print(script_mod.graph)

# graph(%self : __torch__.Net,
#       %x : Float(2, 3, 3, 3, strides=[27, 9, 3, 1], requires_grad=0, device=cpu)):
#   %4 : int = prim::Constant[value=0]()
#   %5 : int = aten::size(%x, %4)
#   %6 : Long(device=cpu) = prim::NumToTensor(%5)
#   %16 : int = aten::Int(%6)
#   %7 : int = prim::Constant[value=1]()
#   %8 : int = aten::size(%x, %7)
#   %9 : Long(device=cpu) = prim::NumToTensor(%8)
#   %17 : int = aten::Int(%9)
#   %10 : int = prim::Constant[value=2]()
#   %11 : int = aten::size(%x, %10)
#   %12 : Long(device=cpu) = prim::NumToTensor(%11)
#   %18 : int = aten::Int(%12)
#   %13 : int = prim::Constant[value=3]()
#   %14 : int = aten::size(%x, %13)
#   %15 : Long(device=cpu) = prim::NumToTensor(%14)
#   %19 : int = aten::Int(%15)
#   %20 : int[] = prim::ListConstruct(%16, %17, %18, %19)
#   %21 : NoneType = prim::Constant()
#   %22 : NoneType = prim::Constant()
#   %23 : Device = prim::Constant[value="cpu"]()
#   %24 : bool = prim::Constant[value=0]()
#   %y : Float(2, 3, 3, 3, strides=[27, 9, 3, 1], requires_grad=0, device=cpu) = aten::ones(%20, %21, %22, %23, %24)
#   %26 : int = prim::Constant[value=6]()
#   %27 : int = prim::Constant[value=0]()
#   %28 : Device = prim::Constant[value="cpu"]()
#   %29 : bool = prim::Constant[value=0]()
#   %30 : NoneType = prim::Constant()
#   %z : Float(2, 3, 3, 3, strides=[27, 9, 3, 1], requires_grad=0, device=cpu) = aten::zeros_like(%x, %26, %27, %28, %29, %30)
#   %32 : int = prim::Constant[value=1]()
#   %33 : Float(2, 3, 3, 3, strides=[27, 9, 3, 1], requires_grad=0, device=cpu) = aten::add(%y, %z, %32)
#   return (%33)

可以看到,在调用 to(torch.device("cuda")) 后,模型的 graph 中记录的 aten::onesaten::zeros_like 的 device 参数仍为 prim::Constant[value="cpu"](),因此在模型 forward 时,它们的输出仍为 cpu Tensor。这是因为 to(device) 只能移动模型中的 buffer(weight、bias 等),无法修改 ScriptModule 的 graph。

torch 官方对以上限制给出的解决方案是,在 trace 前就确定好 pt 模型将要在哪个 device 上执行,并在对应的 device 上 trace 即可。

针对以上限制,训练工具建议根据具体场景选择以下解决方案:

PT 模型执行使用的 device 和 trace 不一致

对于可以确定 pt 模型将仅在 GPU 上执行,只需要修改卡号的情况,我们首先推荐使用 cuda:0,即零号卡进行 trace。在使用模型时,用户可以通过 torch.cuda.set_device 接口,将物理上的任意卡映射为逻辑上的“零卡”,此时使用 cuda:0 trace 出的模型实际将在指定的物理卡上运行。

若 trace 时使用的 device 和执行时使用的 device 存在 CPU、GPU 的不一致,用户可以使用 horizon_plugin_pytorch.jit.to_device 接口实现 pt 模型的 device 迁移。此接口会寻找模型 graph 中的 device 参数,并将它们替换为需要的值。效果如下:

from horizon_plugin_pytorch.jit import to_device

script_mod = to_device(script_mod, torch.device("cuda"))
print(script_mod.graph)

# graph(%self : __torch__.Net,
#       %x.1 : Tensor):
#   %38 : bool = prim::Constant[value=0]()
#   %60 : Device = prim::Constant[value="cuda"]()
#   %34 : NoneType = prim::Constant()
#   %3 : int = prim::Constant[value=0]()
#   %10 : int = prim::Constant[value=1]()
#   %17 : int = prim::Constant[value=2]()
#   %24 : int = prim::Constant[value=3]()
#   %41 : int = prim::Constant[value=6]()
#   %4 : int = aten::size(%x.1, %3)
#   %5 : Tensor = prim::NumToTensor(%4)
#   %8 : int = aten::Int(%5)
#   %11 : int = aten::size(%x.1, %10)
#   %12 : Tensor = prim::NumToTensor(%11)
#   %15 : int = aten::Int(%12)
#   %18 : int = aten::size(%x.1, %17)
#   %19 : Tensor = prim::NumToTensor(%18)
#   %22 : int = aten::Int(%19)
#   %25 : int = aten::size(%x.1, %24)
#   %26 : Tensor = prim::NumToTensor(%25)
#   %32 : int = aten::Int(%26)
#   %33 : int[] = prim::ListConstruct(%8, %15, %22, %32)
#   %y.1 : Tensor = aten::ones(%33, %34, %34, %60, %38)
#   %z.1 : Tensor = aten::zeros_like(%x.1, %41, %3, %60, %38, %34)
#   %50 : Tensor = aten::add(%y.1, %z.1, %10)
#   return (%50)

多卡并行推理

在此场景下,用户需要通过 trace 或 to_device 的方式取得 cuda:0 上的 pt 模型,并且为每块卡单独开启一个进程,通过 torch.cuda.set_device 的方式为每个进程设置不同的默认卡。一个简单的示例如下:

import os
import torch
import signal
import torch.distributed as dist
import torch.multiprocessing as mp
from horizon_plugin_pytorch.jit import to_device

model_path = "path_to_pt_model_file"

def main_func(rank, world_size, device_ids):
    torch.cuda.set_device(device_ids[rank])
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

    model = to_device(torch.jit.load(model_path), torch.device("cuda"))

    # 数据加载,模型 forward,精度计算等内容此处省略

def launch(device_ids):
    try:
        world_size = len(device_ids)
        mp.spawn(
            main_func,
            args=(world_size, device_ids),
            nprocs=world_size,
            join=True,
        )
    # 当按下 Ctrl+c 时,关闭所有子进程
    except KeyboardInterrupt:
        os.killpg(os.getpgid(os.getpid()), signal.SIGKILL)

launch([0, 1, 2, 3])

上述操作对 pt 模型的处理和 torch.nn.parallel.DistributedDataParallel 的做法一致,数据加载和模型精度计算相关内容请参考 Getting Started with Distributed Data Parallel — PyTorch Tutorials

6.4.3.9. 常见问题

import 出错

错误一:Cannot find the extension library(_C.so)

解决方法:

  • 确定 horizon_plugin_pytorch 版本和 cuda 版本是对应的

  • 在 python3 中,找到 horizon_plugin_pytorch 的执行路径,检测该目录下是否有 .so 文件。可能同时存在多个 horizon_plugin_pytorch 的版本,需要卸载只保留一个需要的版本。


错误二:RuntimeError: Cannot load custom ops. Please rebuild the horizon_plugin_pytorch

解决方法:确认本地 CUDA 环境是否正常,如路径、版本等


无法正常 prepare_calibration/qat

RuntimeError: Only Tensors created explicitly by the user (graph leaves) support the deepcopy protocol at the moment

解决方法:一般是模型中包含 non-leaf tensor 才会出现这样的错误,尝试以下方法:

  • 将 prepare_calibration/qat 的 inplace 设为 True

  • 正常 horizon_plugin_pytorch 定义的算子不会出现这种错误,检查模型中自定义的算子是否有 non-leaf tensor 的定义。


prepare_qat 后 forward 报错

TypeError: when calling function <built-in method conv2d of type object at >

解决方法:自定义算子继承了某个 torch 的 Module 算子,导致 prepare_qat 没有被成功转成 qat module。建议使用 submodule 的方式调用 conv2d。


编译报错

ValueError 'unsupported node', aten::unbind

解决方法:将 tensor 当作 list 传入 zip 处理,最终调用了 tensor 原生的 iter,该方法内部使用了 unbind 操作导致以上错误。请检查您的代码。


量化精度异常

QAT/Quantized 精度不符合预期、出现 NAN 或 QAT 初始 loss 相对 float 明显异常

解决方法:请参考 量化训练精度调优指南


使用 torch.jit.load 加载 pt 文件报错

RuntimeError: Unknown builtin op: horizon::bpu_scale_quantization

解决方法:请检查在使用 torch.jit.load 前是否有 import horizon_plugin_pytorch。否则,加载时找不到对应的 horizon 算子。推荐使用 horizon.jit.save/load 保存和加载 pt 文件,避免这样的错误。此外,horizon.jit.save 在保存 pt 时还会额外保存 horizon_plugin_pytorch 的版本号,horizon.jit.load 会检查当前 horizon_plugin_pytorch 的版本是否和保存 pt 时的兼容,若不兼容,会输出相应的警告。

6.4.3.10. 常见使用误区

设置类错误

warning 错误: 无需量化的模块设置了非 None 的 qconfig,例如 前后处理,loss function 等。

正确做法:只需要量化的模块设置 qconfig。


warning 错误: 没有正确设置 march,这样可能导致模型编译失败或部署精度不一致。

正确做法:根据要部署的处理器选择正确的 BPU 架构,例如:


## X5 需要使用 Bayes-e
horizon.march.set_march(horizon.march.March.Bayes)

## X3 需要使用 Bernoulli2
horizon.march.set_march(horizon.march.March.Bernoulli2)

warning 错误: 模型输出节点没有设置成高精度输出,导致量化精度不符合预期。

错误示例如下: 假设模型定义如下:

class ToyNet(nn.Module):
    def __init__(self):
        self.conv0 = nn.Conv2d(4,4,3,3)
        self.relu0 = nn.ReLU()
        self.classifier = nn.Conv2d(4,4,3,3)

    def forward(self, x):
        out = self.conv0(x)
        out = self.relu(out)
        out = self.classifier(out)
        return out

# 错误的设置 qconfig 示例:

float_model = ToyNet()

qat_model = prepare_qat_fx(
    float_model,
    {
        "": default_qat_8bit_fake_quant_qconfig, # 整网设置成 int8 量化
    },
)

正确做法:为了提高模型精度,模型输出节点设置成高精度,示例如下:

qat_model = prepare_qat_fx(
    float_model,
    {
        "module_name": {
            "classifier": default_qat_out_8bit_fake_quant_qconfig, # 网络输出 classifier 层设置为高精度
        },
        "": default_qat_8bit_fake_quant_qconfig, # 其它层设置成 int8 量化
    },
)

方法类错误

warning 错误: Calibration 过程使用多卡。

由于底层限制,Calibration 目前不支持多卡,请使用单卡进行 Calibration 操作。


warning 错误: 模型输入图像数据采用数据格式为 RGB 等非 centered YUV444 格式,这样可能导致模型部署精度不一致。

正确做法:由于 Horizon 硬件支持的图像格式为 centered YUV444,因此建议用户从模型训练开始就直接使用 YUV444 格式作为网络输入进行训练。


warning 错误: 量化训练中使用 qat 模型进行模型精测评测和监控,导致不能及时发现部署时精度异常的问题。

正确做法:导致 QAT 与 Quantized 误差的原因是 QAT 阶段不能完全模拟 Quantized 中纯定点计算逻辑,建议使用 quantized 模型进行模型精度评测和监控。

quantized_model = convert_fx(qat_model.eval())
acc = evaluate(quantized_model, eval_data_loader, device)

网络类错误

warning 错误: 多次调用同一个通过 FloatFunctional() 定义的成员。

错误示例如下:

class ToyNet(nn.Module):
    def __init__(self):
        self.add = FloatFunctional()

    def forward(self, x, y, z)
        out = self.add(x, y)
        return self.add(out, z)

正确做法:禁止在 forward 中多次调用同一个通过 FloatFunctional() 定义的变量。

class ToyNet(nn.Module):
    def __init__(self):
        self.add0 = FloatFunctional()
        self.add1 = FloatFunctional()

    def forward(self, x, y, z)
        out = self.add0.add(x, y)
        return self.add1.add(out, z)

算子类错误

warning 错误: Quantized 模型中部分算子没有经过前期的 calibration 或 QAT,如某后处理算子想要在 BPU 上加速,但是没有经过量化阶段,这时候会导致量化 Inference 失败或部署时的精度异常。

正确做法:Quantized 阶段并非完全不能直接添加算子,如颜色空间转换算子等,具体添加指南详见文档。但是并非所有算子都可以直接添加,比如 cat,这种算子必须在 calibration 或 QAT 阶段统计获得的真实量化参数才能不影响最终精度,有类似需求需要调整网络结构,可以咨询框架研发。


模型类错误

warning 错误: 浮点模型过拟合。

模型过拟合常见判定方法:

  • 对输入数据稍加变换之后,输出结果变化较大

  • 模型参数赋值较大

  • 模型 activation 较大

正确做法:自行解决浮点模型过拟合问题。