带你读 AI 论文:基于 Transformer 的直线段检测
摘要:本文提出了一种基于 Transformer 的端到端的线段检测模型。采用多尺度的 Encoder/Decoder 算法,可以得到比较准确的线端点坐标。作者直接用预测的线段端点和 Ground truth 的端点的距离作为目标函数,可以更好的对线段端点坐标进行回归。
本文分享自华为云社区《论文解读系列十七:基于Transformer的直线段检测》,作者:cver。
1、文章摘要
传统的形态学线段检测首先要对图像进行边缘检测,然后进行后处理得到线段的检测结果。一般的深度学习方法,首先要得到线段端点和线的热力图特征,然后进行融合处理得到线的检测结果。作者提出了一种新的基于 Transformer 的方法,无需进行边缘检测、也无需端点和线的热力图特征,端到端的直接得到线段的检测结果,也即线段的端点坐标。
线段检测属于目标检测的范畴,本文提出的线段检测模型 LETR 是在 DETR(End-to-End Object Detection with Transformers)的基础上的扩展,区别就是 Decoder 在最后预测和回归的时候,一个是回归的 box 的中心点、宽、高值,一个是回归的线的端点坐标。
因此,接下来首先介绍一下 DETR 是如何利用 Transformer 进行目标检测的。之后重点介绍一下 LETR 独有的一些内容。
2、如何利用 Transformer 进行目标检测(DETR)
图 1. DETR 模型结构
上图是 DETR 的模型结构。DETR 首先利用一个 CNN 的 backbone 提取图像的 features,编码之后输入 Transformer 模型得到 N 个预测的 box,然后利用 FFN 进行分类和坐标回归,这一部分和传统的目标检测类似,之后把 N 个预测的 box 和 M 个真实的 box 进行二分匹配(N>M,多出的为空类,即没有物体,坐标值直接设置为 0)。利用匹配结果和匹配的 loss 更新权重参数,得到最终的 box 的检测结果和类别。这里有几个关键点:
首先是图像特征的序列化和编码。
CNN-backbone 输出的特征的维度为 C*H*W,首先用 1*1 的 conv 进行降维,将 channel 从 C 压缩到 d, 得到 d*H*W 的特征图。之后合并 H、W 两个维度,特征图的维度变为 d*HW。序列化的特征图丢失了原图的位置信息,因此需要再加上 position encoding 特征,得到最终序列化编码的特征。
然后是 Transformer 的 Decoder
目标检测的 Transformer 的 Decoder 是一次处理全部的 Decoder 输入,也即 objectqueries,和原始的 Transformer 从左到右一个一个输出略有不同。
另外一点 Decoder 的输入是随机初始化的,并且是可以训练更新的。
二分匹配
Transformer 的 Decoder 输出了 N 个 object proposal ,我们并不知道它和真实的 Ground truth 的对应关系,因此需要经二分图匹配,采用的是匈牙利算法,得到一个使的匹配 loss 最小的匹配。匹配 loss 如下:
得到最终匹配后,利用这个 loss 和分类 loss 更新参数。
3、LETR 模型结构
图 2. LETR 模型结构
Transformer 的结构主要包括 Encoder、Decoder 和 FFN。每个 Encoder 包含一个 self-attention 和 feed-forward 两个子层。Decoder 除了 self-attention 和 feed-forward 还包含 cross-attention。注意力机制:注意力机制和原始的 Transformer 类似,唯一的不同就是 Decoder 的 cross-attention,上文已经做了介绍,就不再赘述。
Coarse-to-Fine 策略
从上图中可以看出 LETR 包含了两个 Transformer。作者称此为 a multi-scale Encoder/Decoder strategy,两个 Transformer 分别称之为 Coarse Encoder/Decoder,FineEncoder/Decoder。也就是先用 CNN backbone 深层的小尺度的 feature map(ResNet 的 conv5,feature map 的尺寸为原图尺寸的 1/32,通道数为 2048) 训练一个 Transformer,即 Coarse Encoder/Decoder,得到粗粒度的线段的特征(训练的时候固定 FineEncoder/Decoder,只更新 Coarse Encoder/Decoder 的参数)。然后把 Coarse Decoder 的输出作为 Fine Decoder 的输入,再训练一个 Transformer,即 FineEncoder/Decoder。Fine Encoder 的输入是 CNN backbone 浅层的 feature map(ResNet 的 conv4,feature map 的尺寸为原图尺寸的 1/16,通道数为 1024),比深层的 feature map 具有更大的维度,可以更好的利用图像的高分辨率信息。
注:CNN 的 backbone 深层和浅层的 feature map 特征都需要先通过 1*1 的卷积把通道数都降到 256 维,再作为 Transformer 的输入
二分匹配
和 DETR 一样, 利用 fine Decoder 的 N 个输出进行分类和回归,得到 N 个线段的预测结果。但是我们并不知道 N 个预测结果和 M 个真实的线段的对应关系,并且 N 还要大于 M。这个时候我们就要进行二分匹配。所谓的二分匹配就是找到一个对应关系,使得匹配 loss 最小,因此我们需要给出匹配的 loss,和上面 DERT 的表达式一样,只不过这一项略有不同,一个是 GIou 一个是线段的端点距离。
4、模型测试结果
模型在 Wireframe 和 YorkUrban 数据集上达到了 state-of–the-arts。
图 3. 线段检测方法效果对比
图 4、线段检测方法在两种数据集上的性能指标对比(Table 1);线段检测方法的 PR 曲线(Figure 6)
版权声明: 本文为 InfoQ 作者【华为云开发者社区】的原创文章。
原文链接:【http://xie.infoq.cn/article/1d516ddad0f448cfe9ee6829b】。文章转载请联系作者。
评论