写点什么

无需 nms,onnxruntime20 行代码玩转 RT-DETR

  • 2023-05-05
    山东
  • 本文字数:2412 字

    阅读完需:约 8 分钟

【前言】 RT-DETR 是由百度近期推出的 DETR-liked 目标检测器,该检测器由 HGNetv2、混合编码器和带有辅助预测头的 Transformer 编码器组成,整体结构如下所示。

本文将采用 RT-DETR 两种不同风格的 onnx 格式,使用 onnxruntime20 行代码,无需 nms 操作即可实现简易部署推理.

一、原生 onnx+ort 推理方式

使用以下命令抽取出模型配置文件和模型参数文件:

python tools/export_model.py -c configs/rtdetr/rtdetr_hgnetv2_l_6x_coco.yml -o weights=https://bj.bcebos.com/v1/paddledet/models/rtdetr_hgnetv2_l_6x_coco.pdparams trt=True --output_dir=output_inference
复制代码

转化模型为 onnx 形式:

paddle2onnx --model_dir=./output_inference/rtdetr_hgnetv2_l_6x_coco/ --model_filename model.pdmodel  --params_filename model.pdiparams --opset_version 16 --save_file rtdetr_hgnetv2_l_6x_coco.onnx
复制代码

抽取后的 onnx 可视化如下:

可以看到,除了图像的输入,还有另外两个输入头,其中,im_shape 指原输入图像的尺寸,scale_factor 指静态图尺度/原输入图像尺度,其实就是缩放的系数。我们将 batch_size 固定为 1,裁减掉不需要使用到的算子:

python -m paddle2onnx.optimize --input_model rtdetr_hgnetv2_l_6x_coco.onnx --output_model rtdetr_hgnetv2_l_6x_coco_sim.onnx --input_shape_dict "{'image':[1,3,640,640]}
复制代码

使用简化后的 onnx 模型进行推理:

import onnxruntime as rtimport cv2import numpy as np
sess = rt.InferenceSession("/home/aistudio/PaddleDetection/rtdetr_hgnetv2_l_6x_coco_sim.onnx")img = cv2.imread("../000283.jpg")org_img = imgim_shape = np.array([[float(img.shape[0]), float(img.shape[1])]]).astype('float32')img = cv2.resize(img, (640,640))scale_factor = np.array([[float(640/img.shape[0]), float(640/img.shape[1])]]).astype('float32')img = img.astype(np.float32) / 255.0input_img = np.transpose(img, [2, 0, 1])image = input_img[np.newaxis, :, :, :]result = sess.run(["reshape2_83.tmp_0","tile_3.tmp_0"], {'im_shape': im_shape, 'image': image, 'scale_factor': scale_factor})for value in result[0]:    if value[1] > 0.5:        cv2.rectangle(org_img, (int(value[2]), int(value[3])), (int(value[4]), int(value[5])), (255,0,0), 2)        cv2.putText(org_img, str(int(value[0]))+": "+str(value[1]), (int(value[2]), int(value[3])), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255,255,255), 1)cv2.imwrite("../result.png", org_img)
复制代码

推理结果:


二、野生 onnx+ort 推理方式

其实通过官方 onnx 模型的格式可以看出,官方已经将所有后处理步骤写入到模型中,此时不需要额外添加后处理代码,是一种比较省心的方式。但对于有强迫症的笔者而言,对于三个输入头的模型实在是看着别扭,因此我更偏向于下面的这种推理方式。同样是抽取官方模型,但此时我们将后处理的所有操作全部摘除,只保留原模型参数:将模型的 exclude_post_process 设置为 True,然后使用同样的代码进行转化:

python tools/export_model.py -c configs/rtdetr/rtdetr_hgnetv2_l_6x_coco.yml -o weights=https://bj.bcebos.com/v1/paddledet/models/rtdetr_hgnetv2_l_6x_coco.pdparams trt=True --output_dir=output_inference_sim
复制代码

将转化后的 pdmodel 进行可视化:

左边为未摘除后处理的 pdmodel,右边为摘除后的 pdmodel,以分类支路为例,我们可以看到,分类支路从 Sigmoid 开始,已经 Sigmoid 和后面的 Children Node 摘除干净,那么可以转化为 onnx 文件,步骤与上面一致。

使用转化后的 onnx 文件进行推理:

import onnxruntime as rtimport cv2import numpy as np
sess = rt.InferenceSession("rtdetr_hgnetv2_l_6x_coco_sim2.onnx")img = cv2.imread("../000283.jpg")img = cv2.resize(img, (640,640))image = img.astype(np.float32) / 255.0input_img = np.transpose(image, [2, 0, 1])image = input_img[np.newaxis, :, :, :]results = sess.run(['scores',  'boxes'], {'image': image})scores, boxes = [o[0] for o in results]index = scores.max(-1)boxes, scores = boxes[index>0.5] * 640, scores[index>0.5]labels = scores.argmax(-1)scores = scores.max(-1)for box, score, label in zip(boxes, scores, labels):    cx, cy, w, h = int(box[0]), int(box[1]), int(box[2]), int(box[3])    cv2.rectangle(img, (cx-int(w/2), cy-int(h/2)), (cx+int(w/2), cy+int(h/2)), (0, 255, 255), 2)    cv2.putText(img, f'{label} : {score:.2f}', (cx-int(w/2), cy-int(h/2)-5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 255), 1)cv2.imwrite('../result.jpg', img)
复制代码

推理结果:

【结尾】 本文介绍了 RT-DETR 两种风格的 onnx 格式和推理方式,不管哪种风格,精度无任何差别,至于是使用哪款,纯凭个人爱好,下一期会出一篇 CNN-liked 代表 YOLOv8 和 DETR-liked 代表 RT-DETR 在 C++部署上的性能差异,在本文结尾先附上本文使用的两个 onnx 模型。

链接:https://pan.baidu.com/s/1AkG3uvILNQhQXeE7z8rYQw ,提取码:pogg 链接:https://pan.baidu.com/s/193Yt99CspP8vZ6ynWOl-ag ,提取码:pogg

本文源自:“ GiantPandaCV”公众号


卡奥斯开源社区是为开发者提供便捷高效的开发服务和可持续分享、交流的 IT 前沿阵地,包含技术文章、群组、互动问答、在线学习、大赛活动、开发者平台、OpenAPI 平台、低代码平台、开源项目等服务,社区使命是让每一个知识工人成就不凡。

 官网链接:Openlab.cosmoplat—打造工业互联网顶级开源社区

用户头像

打造工业互联网顶级开源社区 2023-02-10 加入

卡奥斯开源社区是为开发者提供便捷高效的开发服务和可持续分享、交流的IT前沿阵地,包含技术文章、群组、互动问答、在线学习、开发者平台、OpenAPI平台、低代码平台、开源项目、大赛活动等服务。

评论

发布
暂无评论
无需nms,onnxruntime20行代码玩转RT-DETR_Openlab_cosmoplat_InfoQ写作社区