量化训练及精度调优经验分享
本文提纲:
fx 和 eager 两种量化训练方式介绍
量化训练的流程介绍:以 mmdet 的 yolov3 为例
常用的精度调优 debug 工具介绍
案例分析:模型精度调优经验分享
第一部分:fx 和 eager 两种量化训练方式介绍
首先介绍一下量化训练的原理。
上图为单个神经元的计算,计算形式是加权求和,再经过非线性激活后得到输出,这个输出又可以作为下一个神经元的输入继续运输,所以神经网络的基础运算是矩阵的乘法。如果神经元的计算全部采用 float32 的形式,模型的内存占用和数据搬运都会很占资源。如果用 int8 替换 float32,内存的搬运效率能提高 75%,充分展示了量化的有效性。由于两个 int8 相乘会超出 int8 的表示范围,为了防止溢出,累加器使用 int32 类型的,累加后的结果会再次 requantized 到 int8;
量化的目标就是在尽可能不影响模型精度的情况下降低模型的功耗,实现模型压缩效果,常见的量化方式有后量化训练 PTQ 和量化感知训练 QAT。
量化感知训练其实是一种伪量化的过程,即在训练过程中模拟浮点转定点的量化过程,数据虽然都是表示为 float32,但实际的值会间隔地受到量化参数的限制。具体方法是在某些 op 前插入伪量化节点(fake quantization nodes),伪量化节点有两个作用:
1.在训练时,用以统计流经该 op 的数据的最大最小值,便于在部署量化模型时对节点进行量化
2.伪量化节点参与模型训练的前向推理过程,因此会模型训练中导入了量化损失,但伪量化节点是不参与梯度更新过程的。
上图是模型学习量化损失的示意图, 正常的量化流程是 quantize->mul(int)->dequantize,而伪量化是对原先的 float 先 quantize 到 int,再 dequantize 到 float,这个步骤用于模拟量化过程中 round 操作所带来的误差,用这个误差再去进行前向运算。上图可以比较直观的表示引起误差的原因,从左到右数第 4 个黑点表示一个浮点数,quantize 后映射到 253,dequantize 后取到了第 5 个黑点,这就引起了误差。
地平线基于 PyTorch 开发的 horizon_plugin_pytorch 量化训练工具,同时支持 Eager 和 fx 两种模式。
eager 模式的使用方式建议参考用户手册 -4.2 量化感知训练章节(4.2.2。 快速上手中有完整的快速上手示例,各使用阶段注意事项建议参考 4.2.3。 使用指南)。fx 模式的相关 API 介绍请参考用户手册 -4.2.3.4.2。 主要接口参数说明章节
第二部分:量化训练的流程介绍:以 mmdet 的 yolov3 为例
QAT 流程介绍
准备好浮点模型,加载训好的浮点权重
设置 BPU 架构
算子融合(eager 模式需要,fx 可省略)
设置量化配置
整个 model 使用默认的 qconfig
模型的输出,配置高精度输出
det 模型 head 输出的 loss 损失函数的 qconfig 设置为 None
将浮点模型转换为 qat 模型(示例使用 eager 模式)
开始 qat 训练
可以复用浮点的 train_detector,替换 model 即可
qat 模型转定点(需要 load 训练好的 qat 模型权重)
deploy_model 和 example_input 准备
Trace 模型构建静态 graph,进行编译
eval()使 bn、dropout 等处于正确的状态
编译只能在 cpu 上做
check_model 用于检查算子是否能全部跑在 bpu 上,建议提前检查
如果 qat 精度不达标,如何插入 calibration?
伪量化节点(fake quantize)的三种状态:
CALIBRATION 模式:即不进行伪量化操作,仅观测算子输入输出统计量,更新 scale
QAT 模式:观测统计量并进行伪量化操作。
VALIDATION 模式:不会观测统计量,仅进行伪量化操作。
以下常见误操作会导致一些异常现象:
calibration 之前模型设置为 train()的状态,且未使用**
set_fake_quantize
**,等于是在跑 QAT 训练;calibration 之前模型设置为 eval()的状态,且未使用**
set_fake_quantize
**,会导致 scale 一直处于初始状态,全为 1,calib 不起作用。calibration 之前模型设置为 eval()的状态,且正确使用了**
set_fake_quantize
**,但是在这之后又设置了一遍 model.eval(),这将导致 fake_quant 未处于训练状态,scale 一直处于初始状态,全为 1;
对 mobilenet_v2 模型做 qat 训练的设置
量化节点设置
关键代码:
算子融合
[7.5.5. 算子融合 — Horizon Open Explorer](https://developer.horizon.ai/api/v1/fileData/horizon_j5_open_explorer_cn_doc/plugin/source/advanced_content/op_fusion.html?highlight=算子融合 算子 融合 #)
举个例子:mmcv/cnn/bricks/conv_module.py
eager 方案麻烦的是,基本每个模块都要手动去设置算子融合
反量化节点设置
mmdetection-master/mmdet/models/dense_heads/yolo_head.py
关键代码:
第三部分:常用的精度调优 debug 工具介绍
工具:**集成接口、量化配置检查、模型可视化、相似度对比、统计量、分步量化、异构模型部署 device 检查**
第四部分:模型精度调优分享
模型精度调优时常遇到的问题:
calib 模型的精度和 float 对齐,quantized 模型的精度损失较大
正常情况下,calib/qat 模型的精度和 quantized 模型的精度损失很小(1%), 如果偏差过大,可能是 calib/qat 的流程不对。
原因:calib 模型伪量化节点的状态不正确,导致 calib 阶段,测试的是 float 模型的精度,而 quantized 阶段,测试的是 calib 模型的精度,所以精度损失本质上还是量化精度的损失。
如何避免:
正确设置 calib 训练和评测时的伪量化节点状态。
让客户在 calib 的基础上,做 qat, 评测 qat 模型的精度。(客户的数据量大,qat 时间太长,一直没有选择 qat,导致这个问题被暴露出来了)
如何设置正确的 calib 伪量化节点的状态?(fx 和 eager 都是一样的)
http://model.aidi.hobot.cc/api/docs/horizon_plugin_pytorch/latest/html/user_guide/calibration.html
注意:16 行的 train 在评测时,也要设置 FakeQuantState.VALIDATION,不然 scale 不生效,评测的指标也不对
常见问题:
数据校准之前模型设置为 train()的状态,且未使用**
set_fake_quantize
**,等于 caib 阶段是在跑 QAT 训练;校准的评测阶段,未设置伪量化节点的模式为 VALIDATION, 实际评测的是 float 模型;
总结 2: 如果做 calib,一定要仔细检查伪量化节点状态和模型状态是否正确,避免不符合预期的结果
2.当量化精度损失超过大,如何调优?
使用 model_profiler() 这个集成接口,生成压缩包。
检查是否配置高精度输出、是否存在未融合的算子、是否共享 op、是否算子分布过大 int8 兜不住?
注意:使用 debug 集成接口时,要保证浮点模型训练到位,并传入真实数据
3.多任务模型的精度调优建议
qat 调优策略和常规模型一样,ptq+qat
如果只有一个 head 精度有损失,可以固定其他部分,单独使用这个 head 的数据做 calib
4.calib 和 qat 流程的正确衔接
calib:
qat:
5.检查 conv 高精度输出
方式 1:查看 qconfig_info.txt,重点关注 DeQuantStub 附近的 conv 是不是 float32 输出
qconfig_info.txt
方式 2:打印 qat_model 的最后一层,查看该层是否有 (activation_post_process): FakeQuantize
高精度的 conv:
int8 的 conv
6.检查共享 op
打开 qconfig_info.txt,后面标有(n)的就是共享的
特殊情况:layernorm 在 QAT 阶段是多个小量化算子拼接而成,module 的重复调用,也会产生大量 op 共享的问题
解决办法: 将 layernorm 替换为 batchnorm,测试了 float 精度,没有下降。
7.检查未融合的算子
打开 qconfig_info.txt,全局搜 BatchNorm2d 和 ReLU,如果前面有 conv,那就是没做算子融合
可以融合的算子:
conv+bn
conv+relu
conv+add
conv+bn+relu
conv+bn+add
conv+bn+relu+add
8.检查数据分布特别大的算子
打开 float 模型的统计量分布,一般是 model0_statistic.txt
有两个表,第一个表是按模型结构排列的;第二个表是按数据分布范围排列的
拖到第二个表,看前几行是那些 op
可以看到很多 conv 的分布很异常,使用的是 int8 量化
解决办法:
检查这些 conv 后面是否有 bn,添加 bn 后,数据能收敛一些
如果结构上已经加了 bn,数据分布还大,可以配置 int16 量化
int16 调这两个接口,default_qat_16bit_fake_quant_qconfig 和 default_calib_16bit_fake_quant_qconfig
中间算子的写法和高精度输出类似 model.xx.qconfig = default_qat_16bit_fake_quant_qconfig ()
评论