
征程 6 J6E/M: An Alternative Scheme for Dual-int16 Quantization of linear

2025-05-21


1. Background

When the plugin's precision debug tool identifies a particular linear as sensitive, the output looks like this:


| op_name | sensitive_type | op_type | L1 | quant_dtype | flops |
|---|---|---|---|---|---|
| model.layernorm.rsqrt | activation | horizon_plugin_pytorch.nn.qat.segment_lut.SegmentLUT | 6.52537 | qint16 | 0 (0%) |
| model.linear2 | weight | horizon_plugin_pytorch.nn.qat.linear.Linear | 5.02445 | qint8 | 3072000 (0.00%) |
| model.layernorm.var_mean.pre_mean | activation | horizon_plugin_pytorch.nn.qat.functional_modules.FloatFunctional | 3.1683 | qint16 | 0 (0%) |


As you can see, model.linear2's weight ranks near the top of the sensitivity list, and it is currently quantized as int8.


Next, check baseline_statistic.txt and analysis_statistic.txt, which record the value distributions of model.linear2's input, weight, and output, for example:


| Op Name | Mod Name | Attr | Min | Max | Mean | Var | Shape |
|---|---|---|---|---|---|---|---|
| torch.nn.modules.linear.Linear | model.linear2 | input | 0.0000000 | 15.4210167 | 4.0793311 | 0.2532279 | torch.Size([2, 100, 256]) |
| torch.nn.modules.linear.Linear | model.linear2 | weight | -41.6590347 | 31.2311363 | -0.0053362 | 0.4427260 | torch.Size([60, 256]) |
| torch.nn.modules.linear.Linear | model.linear2 | bias | -0.4426649 | 0.3714900 | 0.0053294 | 0.0112585 | torch.Size([60]) |
| torch.nn.modules.linear.Linear | model.linear2 | output | -32.0065079 | 5.7881856 | 0.4558742 | 3.8736136 | torch.Size([2, 100, 60]) |


Solution: quantize this sensitive linear's weight with int16 (a per-module qconfig sketch follows below).
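As a minimal sketch (not verified against a specific plugin release), the weight of a single sensitive module could be switched to int16 by combining ModuleNameQconfigSetter with get_qconfig, both of which appear in the imports of the full example later in this article. The keyword names passed to get_qconfig (observer / in_dtype / weight_dtype / out_dtype) and the dict-style constructor of ModuleNameQconfigSetter are assumptions here; check the horizon_plugin_pytorch documentation for the exact signatures of your version.

```python
from horizon_plugin_pytorch.dtype import qint8, qint16
from horizon_plugin_pytorch.quantization import prepare
from horizon_plugin_pytorch.quantization.qconfig import get_qconfig, MSEObserver
from horizon_plugin_pytorch.quantization.qconfig_template import (
    ModuleNameQconfigSetter,
    calibration_8bit_weight_16bit_act_qconfig_setter,
)

# Assumed get_qconfig keywords: int16 weight/output, int8 input
# (input and weight cannot both be int16 on this BPU, see section 2).
linear2_qconfig = get_qconfig(
    observer=MSEObserver,
    in_dtype=qint8,
    weight_dtype=qint16,
    out_dtype=qint16,
)

# Assumed constructor: mapping from module name to qconfig for the sensitive linear.
module_setter = ModuleNameQconfigSetter({"linear2": linear2_qconfig})

calib_model = prepare(
    model.eval(),      # the float model from your training flow
    example_input,     # a representative example input
    qconfig_setter=(
        module_setter,                                      # per-module override
        calibration_8bit_weight_16bit_act_qconfig_setter,   # global default
    ),
)
```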


But what if the linear's input, weight, and output are all required to be quantized as int16?

2. Background Knowledge

On 征程 6E/M, the Horizon BPU supports linear as follows:


(At the time of publication, the support matrix was shown here as an image in the original post.)



As the table shows, input and weight cannot both be int16.

3. Linear with Input and Weight Both int16

When both the input and the weight of a linear require int16 quantization, the linear can be replaced with broadcast mul + sum for validation, without retraining the float model.


Similarities and differences: at the float level, broadcast_mul_sum_replace_linear is equivalent to the linear, but the quantization differs: a Linear weight is quantized per channel, whereas the same weight fed into a mul is quantized per tensor. In general, going from per-channel int8 weight to per-tensor int16 is a net gain in accuracy. A float-level equivalence check is sketched below.
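As a quick float-level sanity check (plain PyTorch, shapes taken from the statistics table above), the broadcast mul + sum form reproduces the linear's output up to floating-point accumulation order:

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 100, 256)   # linear input
w = torch.randn(60, 256)       # linear weight
b = torch.randn(60)            # linear bias

# Reference: standard linear, y = x @ w.T + b -> [2, 100, 60]
ref = F.linear(x, w, b)

# Replacement: broadcast multiply, then sum over the feature dimension
broadcast_mul = x.reshape(2, 100, 1, 256) * w.reshape(1, 1, 60, 256)  # [2, 100, 60, 256]
out = broadcast_mul.sum(dim=-1) + b                                   # [2, 100, 60]

print((ref - out).abs().max())              # tiny accumulation-order noise
print(torch.allclose(ref, out, atol=1e-4))  # True
```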


Replacement procedure: apply the replacement after float training is finished, then run calibration + QAT.


import torch.nn as nn
from horizon_plugin_pytorch.quantization import QuantStub
from torch.quantization import DeQuantStub


class SmallModel(nn.Module):
    def __init__(self, linear2_weight, linear2_bias):
        super(SmallModel, self).__init__()
        # First Linear: input [2, 100, 256] -> output [2, 100, 256]
        self.linear1 = nn.Linear(256, 256)
        self.layernorm = nn.LayerNorm(256)  # normalizes over the last dimension
        self.relu = nn.ReLU()
        # Second Linear: input [2, 100, 256] -> output [2, 100, 60]
        # self.linear2 = nn.Linear(256, 60)
        self.linear2_weight = linear2_weight
        self.linear2_bias = linear2_bias
        # Third Linear: input [2, 100, 60] -> output [2, 100, 60]
        self.linear3 = nn.Linear(60, 60)
        self.quant = QuantStub()
        self.dequant = DeQuantStub()
        self.quant_linear2_weight = QuantStub()
        self.quant_linear2_bias = QuantStub()

    def forward(self, x):
        x = self.quant(x)
        linear2_weight = self.quant_linear2_weight(self.linear2_weight)
        linear2_bias = self.quant_linear2_bias(self.linear2_bias)
        # First Linear
        x = self.linear1(x)    # [2, 100, 256]
        x = self.layernorm(x)  # [2, 100, 256]
        x = self.relu(x)       # [2, 100, 256]

        # Second Linear
        # x = self.linear2(x)  # [2, 100, 60]
        # ===================================
        # Replace the linear with broadcast mul + sum
        # ===================================
        # Broadcast multiply: input [2, 100, 256] against weight [60, 256]
        broadcast_mul = x.reshape(2, 100, 1, 256) * linear2_weight.reshape(1, 1, 60, 256)  # [2, 100, 60, 256]
        # Sum over the last dimension: emulates the linear layer's weighted sum
        sum_output = broadcast_mul.sum(dim=-1)  # [2, 100, 60]
        # Add the bias
        x = sum_output + linear2_bias  # [2, 100, 60]

        # Third Linear
        x = self.linear3(x)
        x = self.dequant(x)
        return x


In the broadcast mul + sum replacement, all of the involved ops support int16.


Caveat: if the vast majority of the mul's output values lie near 0, MSE calibration is strongly affected by outliers, so the output scale becomes very large, the many small values near 0 get rounded to 0, and the resulting sum deviates badly.


Scope of impact: the effect is especially severe when the mul is followed by a sigmoid or an add + sigmoid.


Solution: set a fixed scale of 7/32767 on the mul output: sigmoid does not need large inputs, while the mul's output distribution calls for a small scale. The sketch below illustrates the effect.
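A toy numerical sketch (plain PyTorch with made-up values, not the plugin's calibration) of why an outlier-driven scale wipes out the sum while the small fixed scale preserves it:

```python
import torch

# Most of the products feeding one sum are tiny; a rare outlier elsewhere in the
# tensor dominates a max/MSE-style calibrated scale.
products = torch.full((256,), 2e-4)   # the values the sum actually depends on
outlier = torch.tensor([30.0])        # rare large value in the same tensor

def fake_quant_int16(x, scale):
    # symmetric int16 fake quantization: round(x / scale), clamp to [-32768, 32767]
    return torch.clamp(torch.round(x / scale), -32768, 32767) * scale

scale_from_outlier = outlier.abs().max() / 32767   # ~9.2e-4, set by the outlier
fixed_scale = 7 / 32767                            # ~2.1e-4, the fixed scale above

print(products.sum())                                        # ~0.0512, true sum
print(fake_quant_int16(products, scale_from_outlier).sum())  # 0.0: every product rounds to 0
print(fake_quant_int16(products, fixed_scale).sum())         # ~0.0547: magnitude preserved
```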

4. End-to-End Example

As the table shows, among the scenarios where the linear needs int16 quantization, input/output int16 has the shortest latency, followed by weight/output int16 with input int8, and all three in int16 is the slowest. For these three cases, complete examples are provided below for reference.


(The configuration details were shown here as an image in the original post.)



Note: the replacement is not strictly equivalent; it is for reference only.

4.1 Example Code

import torch
from horizon_plugin_pytorch import set_march, March
set_march(March.NASH_M)
from horizon_plugin_pytorch.quantization import prepare, set_fake_quantize, FakeQuantState
from horizon_plugin_pytorch.quantization import QuantStub
from horizon_plugin_pytorch.quantization.hbdk4 import export
from horizon_plugin_pytorch.quantization.qconfig_template import calibration_8bit_weight_16bit_act_qconfig_setter, ModuleNameQconfigSetter
from horizon_plugin_pytorch.quantization.qconfig import get_qconfig, MSEObserver, MinMaxObserver
from horizon_plugin_pytorch.dtype import qint8, qint16
from torch.quantization import DeQuantStub
import torch.nn as nn
from horizon_plugin_pytorch.quantization import hbdk4 as hb4
from hbdk4.compiler import convert, save, hbm_perf, visualize, compile

# Define the network structure
class SmallModel(nn.Module):
    def __init__(self, linear2_weight, linear2_bias):
        super(SmallModel, self).__init__()
        # First Linear: input [2, 100, 256] -> output [2, 100, 256]
        self.linear1 = nn.Linear(256, 256)
        self.layernorm = nn.LayerNorm(256)  # normalizes over the last dimension
        self.relu = nn.ReLU()
        # Second Linear: input [2, 100, 256] -> output [2, 100, 60]
        # self.linear2 = nn.Linear(256, 60)
        self.linear2_weight = linear2_weight
        self.linear2_bias = linear2_bias
        # Third Linear: input [2, 100, 60] -> output [2, 100, 60]
        self.linear3 = nn.Linear(60, 60)
        self.quant = QuantStub()
        self.dequant = DeQuantStub()
        self.quant_linear2_weight = QuantStub()
        self.quant_linear2_bias = QuantStub()

    def forward(self, x):
        x = self.quant(x)
        linear2_weight = self.quant_linear2_weight(self.linear2_weight)
        linear2_bias = self.quant_linear2_bias(self.linear2_bias)
        # First Linear
        x = self.linear1(x)    # [2, 100, 256]
        x = self.layernorm(x)  # [2, 100, 256]
        x = self.relu(x)       # [2, 100, 256]

        # Second Linear
        # x = self.linear2(x)  # [2, 100, 60]
        # ===================================
        # Replace the linear with broadcast mul + sum
        # ===================================
        # Broadcast multiply: input [2, 100, 256] against weight [60, 256]
        broadcast_mul = x.reshape(2, 100, 1, 256) * linear2_weight.reshape(1, 1, 60, 256)  # [2, 100, 60, 256]
        # Sum over the last dimension: emulates the linear layer's weighted sum
        sum_output = broadcast_mul.sum(dim=-1)  # [2, 100, 60]
        # Add the bias
        x = sum_output + linear2_bias  # [2, 100, 60]

        # Third Linear
        x = self.linear3(x)
        x = self.dequant(x)
        return x

float_ckpt_path = "model_path/float-checkpoint.ckpt"
float_state_dict = torch.load(float_ckpt_path)
# Iterate over the OrderedDict and pick out the "linear2" keys
for key, value in float_state_dict.items():
    # if "linear2" in key:
    #     print(f"Key: {key}, Value: {value.shape}")
    if key == "linear2.weight":
        linear2_weight = value
    if key == "linear2.bias":
        linear2_bias = value

# example_input = torch.randn(2, 100, 256)
file_path = "random_data.pt"
example_input = torch.load(file_path)
model = SmallModel(linear2_weight, linear2_bias)
missing_keys, unexpected_keys = model.load_state_dict(float_state_dict, strict=False)
print("missing_keys & unexpected_keys:", missing_keys, '\n', unexpected_keys)

# Forward pass
output = model(example_input)
print("float output:", output)
torch.save(output, "model_path/6_model_float_output.pt")
print("input shape:", example_input.shape)
print("output shape:", output.shape)

# A global march indicating the target hardware version must be set before prepare/QAT.
set_march(March.NASH_M)

calib_model = prepare(
    model.eval(),
    example_input,
    qconfig_setter=(
        calibration_8bit_weight_16bit_act_qconfig_setter,
    ),
)

calib_model.eval()
set_fake_quantize(calib_model, FakeQuantState.CALIBRATION)
calib_model(example_input)

calib_model.eval()
set_fake_quantize(calib_model, FakeQuantState.VALIDATION)
calib_out = calib_model(example_input)
print("calib output:", calib_out)
qat_bc = export(calib_model, example_input)
hb_quantized_model = convert(qat_bc, March.NASH_M)

4.2 Comparing the Output Consistency of the Alternatives

  • linear2 with weight / input / output all int16


float output: tensor([[[-0.3016,  0.1338, -0.5251,  ..., -0.0551, -0.2093, -0.0308],
         [-0.1969, -0.0131, -0.3287,  ...,  0.3234, -0.0869, -0.0637],
         [-0.3056,  0.1478, -0.2673,  ...,  0.2355, -0.3487,  0.0134],
         ...,
         [-0.3990, -0.0389, -0.1686,  ..., -0.0046, -0.4131,  0.0482],
         [-0.1059,  0.2431, -0.1886,  ...,  0.0787, -0.3454,  0.0231],
         [-0.2134, -0.1071, -0.0575,  ...,  0.3434, -0.1661,  0.2248]]],
       grad_fn=<ViewBackward0>)

calib output: tensor([[[-0.3038,  0.1370, -0.5269,  ..., -0.0571, -0.2111, -0.0296],
         [-0.1975, -0.0111, -0.3280,  ...,  0.3215, -0.0884, -0.0637],
         [-0.3052,  0.1488, -0.2677,  ...,  0.2348, -0.3479,  0.0132],
         ...,
         [-0.3988, -0.0393, -0.1662,  ..., -0.0055, -0.4117,  0.0484],
         [-0.1058,  0.2442, -0.1890,  ...,  0.0780, -0.3447,  0.0240],
         [-0.2142, -0.1061, -0.0587,  ...,  0.3422, -0.1657,  0.2255]]],
       grad_fn=<ViewBackward0>)


  • broadcast mul + sum, all int16


float output: tensor([[[-0.3016,  0.1338, -0.5251,  ..., -0.0551, -0.2093, -0.0308],
         [-0.1969, -0.0131, -0.3287,  ...,  0.3234, -0.0869, -0.0637],
         [-0.3056,  0.1478, -0.2673,  ...,  0.2355, -0.3487,  0.0134],
         ...,
         [-0.3990, -0.0389, -0.1686,  ..., -0.0046, -0.4131,  0.0482],
         [-0.1059,  0.2431, -0.1886,  ...,  0.0787, -0.3454,  0.0231],
         [-0.2134, -0.1071, -0.0575,  ...,  0.3434, -0.1661,  0.2248]]],
       grad_fn=<ViewBackward0>)

calib output: tensor([[[-0.3038,  0.1370, -0.5269,  ..., -0.0571, -0.2111, -0.0296],
         [-0.1975, -0.0111, -0.3280,  ...,  0.3215, -0.0884, -0.0637],
         [-0.3051,  0.1487, -0.2678,  ...,  0.2349, -0.3478,  0.0132],
         ...,
         [-0.3988, -0.0392, -0.1662,  ..., -0.0055, -0.4117,  0.0484],
         [-0.1058,  0.2442, -0.1890,  ...,  0.0780, -0.3447,  0.0240],
         [-0.2142, -0.1061, -0.0586,  ...,  0.3423, -0.1657,  0.2255]]],
       grad_fn=<ViewBackward0>)
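Beyond eyeballing the tensors, the closeness can be quantified with a small comparison script (plain PyTorch; the calib output path here is hypothetical, saved the same way as the float output in the example above):

```python
import torch

float_out = torch.load("model_path/6_model_float_output.pt").detach()
calib_out = torch.load("model_path/6_model_calib_output.pt").detach()  # hypothetical path

# Cosine similarity over the flattened tensors plus the worst-case absolute error
cos = torch.nn.functional.cosine_similarity(float_out.flatten(), calib_out.flatten(), dim=0)
max_abs = (float_out - calib_out).abs().max()
print(f"cosine similarity: {cos.item():.6f}, max abs error: {max_abs.item():.6f}")
```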

