一、前言
在端侧部署时(如在移动设备、嵌入式芯片上),为了加速模型推理、减少功耗和资源开销,往往会将某些计算复杂的函数(如 exp、log、tanh、sigmoid、softmax 等)转为查表操作。查表算子在转成定点计算时不可避免地会出现误差,此时就需要定位引起精度下降的具体算子以及对其进行针对性的优化。
本文将讲述在使用地平线 QAT 链路基于 J6 系列平台进行模型部署时,对查表算子进行精度调优的相关手段,主要包括以下内容:
如何确定是定点查表导致的误差?
如何确定具体的定点查表算子导致的误差?
二、如何确定是定点查表导致的误差
当 QAT 模型的精度处于正常状态,然而 QAT 模型 export 出来的qat.bc
文件精度却出现异常情况时,并且在这种情况下,我们已经对导出 qat.bc
文件的整个流程进行了全面且细致的检验,确认该流程不存在任何错误。那么在这样的前提条件之下,我们此时就可以开始着手验证是否是由于查表算子的因素导致了精度下降。
horizon_plugin_pytorch
提供了 api 来辅助进行查表算子精度的验证,具体思路是将 qat model 中所有的查表算子转成定点,然后在验证集上进行精度的评测,如果和 qat.bc 的现象一样都出现了比较严重的精度下降问题,那么就说明是因为查表算子导致的误差。
以下是horizon_plugin_pytorch
提供的 api 的使用示例,如下所示:
import torch
from horizon_plugin_pytorch.quantization import prepare,set_fake_quantize,FakeQuantState
import copy
qat_model = prepare(
copy.deepcopy(float_model),
example_inputs=example_input,
qconfig_setter=default_calibration_qconfig_setter,)
print("--"*20+"Prepare qat model success"+"--"*20)
state_dict = torch.load(ckpt_path)
qat_model.load_state_dict(new_state_dict,strict=True)
print("--"*20+"Load qat ckpt success"+"--"*20)
qat_model.eval()
set_fake_quantize(qat_model, FakeQuantState.VALIDATION)
#将qat model中所有的查表算子转成定点
qat_lut_model=copy.deepcopy(qat_model)
from horizon_plugin_pytorch.nn.qat.segment_lut import QuantizedQATSegmentLUT
QuantizedQATSegmentLUT.convert_segment_lut(qat_lut_model)
#评测查表转定点的qat_lut model的精度
evaluate(qat_lut_mode,val_dataloader, .....)
#如果相对于qat model精度下降比较严重,那么就说明是查表算子导致的问题
复制代码
三、如何确定引起误差的查表算子?
在确定是查表算子导致的精度误差后,我们需要进一步确定是哪些查表算子导致的误差。一般来说,模型中会包含多个多种查表算子。
具体方法是结合 QAT 精度 debug 工具horizon_plugin_profile
来做 qat_model 和 qat_lut_model(查表转定点的 qat model)的精度 debug,然后根据敏感度来确定具体的查表算子。
以下是 QAT 精度 debug 工具的使用示例,如下所示:
from horizon_plugin_profiler import QuantAnalysis, ModelProfiler
qa = QuantAnalysis(
baseline_model=qat_model,
analysis_model=qat_lut_model,
analysis_model_type="fake_quant",
device_ids=0, # GPU index,若不指定则在 CPU 上
out_dir=output_dir)
qa.auto_find_bad_case(data_generator=val_dataloader,metric="L1")
qa.run()
qa.compare_per_layer()
qa.sensitivity(metric="L1")
复制代码
debug 工具运行完成后,在out_dir
会生成系列产物,这里我们主要关注逐层相似度和输出敏感度,compare_per_layer_out.csv
和output_xxxx_L1_sensitive_ops.txt
文件。
在获得敏感度 txt 文件后,我们就根据敏感度顺序逐步来确认引起误差的算子,具体思路如下:
将 qat mode 中所有的查表算子转成定点;
然后将敏感度靠前的差表算子回退到浮点进行精度评测;
如果在将某个/类别查表算子回退到浮点以后,精度指标与 qat model 区别不大,那么就说明是这个/类查表导致的精度下降。
下面为具体的操作代码:
import torch
from horizon_plugin_pytorch.quantization import prepare,set_fake_quantize,FakeQuantState
import copy
qat_model = prepare(
copy.deepcopy(float_model),
example_inputs=example_input,
qconfig_setter=default_calibration_qconfig_setter,)
print("--"*20+"Prepare qat model success"+"--"*20)
state_dict = torch.load(ckpt_path)
qat_model.load_state_dict(new_state_dict,strict=True)
print("--"*20+"Load qat ckpt success"+"--"*20)
qat_model.eval()
set_fake_quantize(qat_model, FakeQuantState.VALIDATION)
#将qat model中所有的查表算子转成定点
qat_lut_model=copy.deepcopy(qat_model)
from horizon_plugin_pytorch.nn.qat.segment_lut import QuantizedQATSegmentLUT
QuantizedQATSegmentLUT.convert_segment_lut(qat_lut_model)
#将敏感度靠前的算子回退到浮点
#"decoder.decoder._generated_log_1.log"为算子名称
qat_lut_model.get_submodule("decoder.decoder._generated_log_1.log").quantized_forward = False
....
#评测查表转定点的qat_lut model的精度
evaluate(qat_lut_mode,val_dataloader, .....)
#如果相对于qat model精度下降比较严重,那么就说明是查表算子导致的问题
复制代码
这边补充说明两点:
在下一篇文章中,我们将演示在定位到具体算子后,如何进行精度调优,敬请期待!!!
评论