from hbdk4.compiler.torch import export
from hbdk4.compiler import statistics, save, load,visualize,compile
from hbdk4.compiler.march import March
from hbdk4.compiler import convert,hbm_perf
# Load the fake-quantized (QAT) bc and dump an ONNX visualization of it.
model = load("qat.bc")
visualize(model, "qat_ori.onnx")

# Split the batched input along dim 0 — a required step for batched
# NV12 (pyramid) inputs before preprocessing nodes can be inserted.
graph = model.functions[0]
batch_input = ["_input_0"]
for arg in graph.inputs[::-1]:
    for prefix in batch_input[::-1]:
        if prefix in arg.name:
            arg.insert_split(dim=0)

# Visualize the model after the batch split.
visualize(model, "qat_split_batch.onnx")
# Insert preprocessing nodes onto the model's boundary inputs.
func = model.functions[0]

# Input names fed from pyramid (NV12) vs. DDR at deployment time; both
# lists can be read off the qat_split_batch.onnx visualization.
pyramid_input = ['_input_0_0','_input_0_1','_input_0_2','_input_0_3','_input_0_4','_input_0_5']
ddr_input = "_input_1"

# Attach an NV12 conversion chain to every pyramid-fed input.
for arg in func.inputs[::-1]:
    print(arg.name)
    if arg.name in pyramid_input:
        # pyramid & resizer only accept an NHWC input layout
        arg.insert_transpose(permutes=[0, 3, 1, 2])
        # training used YUV444 images, so the preprocess mode is None
        arg.insert_image_preprocess(mode=None, divisor=1, mean=[128, 128, 128], std=[128, 128, 128])
        arg.insert_image_convert("nv12")
        print("-----insert nv12 success-----")

# Resizer-input variant (disabled; kept for reference):
# for arg in func.inputs[::-1]:
#     if arg.name in resizer_input:
#         # pyramid & resizer only accept an NHWC input layout
#         node = arg.insert_transpose(permutes=[0, 3, 1, 2])
#         node = arg.insert_image_preprocess(mode=None, divisor=1, mean=[128, 128, 128], std=[128, 128, 128])
#         node.insert_roi_resize("nv12")

# Convert the DDR-fed input layout from NCHW to NHWC.
for arg in func.inputs[::1]:
    if arg.name == ddr_input:
        arg.insert_transpose(permutes=[0, 2, 3, 1])

# Visualize the model after all preprocessing nodes are in place.
visualize(model, "qat_preprocess.onnx")
# Persist the preprocessed HBIR as a bc, then convert the fake-quantized
# model into a fixed-point (quantized) one; advice=True dumps per-operator
# information into advice_path.
save(model, "qat_preprocess.bc")
quantized_model = convert(model, 'nash-e', advice=True, advice_path='./')
# Visualize the quantized bc.
visualize(quantized_model, "quantized_ori.onnx")
# convert() leaves quantize/dequantize (and related boundary) ops at the
# head/tail of the bc by default; they can be stripped manually. This table
# maps concrete IR op identifiers to the generic node types remove_op takes.
_QUANT_OPS = {
    "qnt.quantize": "Quantize",
    "qnt.dequantize": "Dequantize",
    "hbtl.call::quant::qcast": "Quantize",
    "hbtl.call::quant::dcast": "Dequantize",
}
_LAYOUT_OPS = {
    "hbir.transpose": "Transpose",
    "hbtl.call::native::Transpose": "Transpose",
    "hbir.cast_type": "Cast",
    "hbir.reshape": "Reshape",
    "hbtl.call::native::Cast": "Cast",
    "hbtl.call::native::Reshape": "Reshape",
}
node_type_mapping = {**_QUANT_OPS, **_LAYOUT_OPS}
def get_type_for_hbtl_call(attached_op):
    """Return the fully qualified type string for an ``hbtl.call`` op.

    One op type may have several backend implementations, so the op's
    schema namespace and signature are appended to disambiguate, yielding
    ``<type>::<namespace>::<signature>``.
    """
    schema = attached_op.schema
    return "::".join([attached_op.type, schema.namespace, schema.signature])
def remove_op(func, op_type=None, op_name=None):
    """Remove matching ops attached to the graph's boundary inputs/outputs.

    Walks every function input and output value; when the op attached
    there is removable and matches either ``op_name`` (membership test of
    the op's name in ``op_name``) or ``op_type`` (a generic type such as
    "Quantize"/"Dequantize", resolved via ``node_type_mapping``), it is
    detached from the graph.

    Args:
        func: the bc function whose inputs/outputs are scanned.
        op_type: generic node type(s) to remove (keys of remove requests,
            e.g. "Quantize"); membership is tested with ``in``.
        op_name: op name(s) to remove; membership is tested with ``in``.

    Raises:
        ValueError: if a matching op fails to be removed.
    """
    for loc in func.inputs + func.outputs:
        # Skip boundary values whose attached op cannot be removed at all.
        if not loc.is_removable[0]:
            continue
        attached_op = loc.get_attached_op[0]
        removed = None
        diagnostic = None
        # Op name formats in hbir are not yet fully stable; matching by
        # op type is the recommended way to delete nodes.
        if op_name and attached_op.name in op_name:
            removed, diagnostic = loc.remove_attached_op()
        elif op_type and attached_op.type in node_type_mapping \
                and node_type_mapping[attached_op.type] in op_type:
            removed, diagnostic = loc.remove_attached_op()
        elif attached_op.type == "hbtl.call":
            # One op type may map to several backend implementations, so the
            # schema "signature" is used to identify the concrete kind.
            node_type = get_type_for_hbtl_call(attached_op)
            if op_type and node_type in node_type_mapping \
                    and node_type_mapping[node_type] in op_type:
                removed, diagnostic = loc.remove_attached_op()
        if removed is True:
            print(f"Remove node {op_type} successfully")
        elif removed is False:
            # BUG FIX: the original passed several positional args to
            # ValueError, producing a tuple as the exception message.
            raise ValueError(
                f"Remove node type {op_type} "
                f"failed when deleting {attached_op.name} operator, "
                f"error: {diagnostic}"
            )
# Strip the boundary quantize/dequantize nodes that convert() leaves in.
func = quantized_model[0]

# Optional extra cleanups (disabled; kept for reference):
# remove_op(func, op_type="Reshape")
# remove_op(func, op_type="Cast")
# remove_op(func, op_type="Transpose")
remove_op(func, op_type="Dequantize")
remove_op(func, op_type="Quantize")
print("-----remove_quant_dequant OK-----")

# Persist and visualize the modified quantized model.
save(quantized_model, "quantized_modified.bc")
visualize(quantized_model, "quantized_remove_dequa.onnx")
# Compile the quantized bc into a deployable hbm.
print("-----start to compile model-----")
params = {'jobs': 48, 'balance': 100, 'progress_bar': True,
          'opt': 2, 'debug': True}
# BUG FIX: the original passed the undefined name `quantized_bc`
# (NameError); the quantized bc produced by convert() above is bound
# to `quantized_model`.
compile(
    quantized_model,
    march="nash-e",
    path="model.hbm",
    **params
)
print("-----end to compile model-----")

# Estimate model performance from the compiled hbm.
print("-----start to perf model-----")
save_path = "./perf"
hbm_perf('model.hbm', save_path)
# End of script