MAE 自监督算法介绍和基于 EasyCV 的复现
作者:贺弘、谦言、临在
导言
自监督学习(Self-Supervised Learning)能利用大量无标注的数据进行表征学习,然后在特定下游任务上对参数进行微调。通过这样的方式,能够在较少有标注数据上取得优于有监督学习方法的精度。近年来,自监督学习受到了越来越多的关注,如 Yann Lecun 也在 AAAI 上讲 Self-Supervised Learning 是未来的大势所趋。在 CV 领域涌现了如 SwAV、MOCO、DINO、MoBY 等一系列工作。MAE 是 kaiming 继 MOCO 之后在自监督学习领域的又一力作。首先,本文会对 MAE 进行解读,然后基于 EasyCV 库的精度复现过程及其中遇到的一些问题作出解答。
概述
MAE 的做法很简单:随机 mask 掉图片中的一些 patch,然后通过模型去重建这些丢失的区域。包括两个核心的设计:1.非对称编码-解码结构 2.用较高的掩码率(75%)。通过这两个设计 MAE 在预训练过程中可以取得 3 倍以上的训练速度和更高的精度,如 ViT-Huge 能够通过 ImageNet-1K 数据上取得 87.8%的准确率。
模型拆解
MAE 属于自编码器(AutoEncoder)的一种,由编码器和解码器两个部分组成。类似于常见的自编码器,MAE 会先通过编码器将图片 patch 映射到隐空间。然后,基于解码器将隐空间上的特征变量重构成图片 patch。和常见自编码器的区别是非对称的编码解码结构。这个非对称性主要体现在以下两点:
轻量化的解码器结构在编码器阶段,仅将未被 mask 掉的图片 patch 作为输入。
在解码器阶段会将编码器输出的隐变量和 mask token 共同作为输入去重建完成的图片。
掩码策略
首先,直接采用 ViT 的做法将图片分成不重叠的 patch(如 vit-b 会将图片划分成 16x16 的图像块),然后通过均匀采样策略对这些 patch 进行采样,并丢弃未被选中的部分。MAE 所采用的掩码策略有如下两个特点:
1.在算法中,使用了 75%的 masking ratio 来丢弃图片 patch。作者指出,通过 high masking ratio 可以有效减少输入的冗余程度,使重建任务不能够通过简单的参考邻近 patch 来完成。文中,也通过实验证明了这一观点。
关于 Masking ratio 的实验是 MAE 最精彩的一部分,随着 mask ratio 的增加,fine-tuning 和 linear proing 的精度逐渐攀升,甚至到 75%还没有下降,这一点打破了 BERT(15%)、BEiT(40%)的做法,进一步将 mask 预训练方式在 NLP 领域的成功在 CV 领域实现复制。
2.采用了均匀采样策略可以有效的避免 potential center bias(丢弃掉的 patch 都靠近图片中心)。对 mask 策略的消去实验如下表所示。
编码器
MAE encoder 采用的是 ViT 结构。在对图像 patch 进行采样后,仅保留 25%未被 mask 的图像 patch 作为输入,通过 linear Projection 进行编码后,加上 positional embedding,然后输入到一系列的 Transformer blocks 中。相比于 Bert 中用 mask token 来代替被 mask 区域的做法,MAE encoder 直接舍弃掉了 mask 的部分,通过这种方式可以有效的减少预训练过程中需要消耗的计算资源和训练时间。
文中,作者对编码器是否保留 mask token 进行了消融实验,可以看出在编码器阶段舍弃 mask token 不会对预训练模型的表征能力造成影响,同时能够显著的加速训练进程。
解码器
MAE decoder 由一连串的 Transfomer block 组成。和 encoder 不同的是,MAE decoder 的输入不仅包括未被 mask 的图像 patch 经过 encoder 编码后的特征,还包括了被 mask 掉的部分。对于 mask 掉部分的输入,会用一个共享参数,且可学习的 mask token 代替作为输入。除此之外,为了保证不同的 mask token 能够区分在图像中的不同位置,在输入到 decoder 之前,会对整体的输入加上 positional embedding。
在 MAE 中,解码器仅会在预训练阶段用于图片的重建工作。文中采用了轻量化的解码器结构,对于每个 token 的计算量仅有相对于解码器的 10%以下。通过这种设计,就算在解码阶段用了完整数量的 token 作为输入,对计算资源的消耗也不会显著增加。
文中,作者对解码器的 depth 和 width 两个维度进行对比实验,可以看出一个较轻量化的解码器,就足以是模型学习到有效的表征。
重建目标
MAE 预训练任务的目标是重建被 mask 掉的像素值。MAE decoder 输出关于每个图像 patch 的表征后,会经过一个 linear projection 层映射成与图像像素数目相同维度的向量(PxPx3)。仅采用 MSE 作为损失函数,计算预测向量和被 mask 掉像素值之前的 MSE loss。
需要额外指出的是,作者使用了归一化后的图像 patch 作为重建的目标。通过实验证明,这种做法可以提升模型的表征能力。
模型评价
文中除了从 linear probing 和 Finetuning 两个角度对模型的表征能力做出评价外,还采用了 Partial Fine-tuning 的方式进行评价,相比于 linear probing 这种之前普遍采用的评价方式,能够更好的反映预训练模型对非线性特征的表征能力。从下图可以看出,MAE 算法仅仅对一个 transformer block 进行 fintune 精度就从 73.5%提升到 81%。同时与 MOCOv3 相比,MOCOv3 虽然在 linear probing 的时候具有更高的精度,但是在 partial fine-tuning 时,MAE 的精度都要高于 MOCOv3。可以看出,MAE 虽然对线性特征的表征能力要弱于 MOCOv3,但是具有更好的非线性特征表征能力。
EasyCV 介绍
EasyCV 是阿里巴巴开源的基于 Pytorch,以自监督学习和 Transformer 技术为核心的 all-in-one 视觉算法建模工具。在数据层面,EasyCV 提供了提供了不同数据源(data_source)的抽象,支持多种开源数据集例如 Cifar、ImageNet、CoCo 等,并将各种数据预处理抽象成若干独立的 pipeline,可以通过配置文件灵活的配置数据预处理流程。在 API 层面,提供了统一的训练、评估、模型导出、预测的 API。因此,基于 EasyCV,仅需要实现模型部分的代码,就可以很便捷的完成 MAE 的复现。
除此之外,EasyCV 支持 aliyun PAI 产品中方便的进行部署(如 PAI-DLC),无需多余的修改即可在 DLC 上同时进行多机或者多组实验,加快复现进度。
复现过程 & 踩坑
总结接下来我们介绍如何在 EasyCV 框架中进行 MAE 算法的复现和踩坑总结,首先,说明一下预训练的整体流程。
1.将输入图像划分成不同的 patch,并将 patch 经过 Linear Projection 进行映射,再加上 positional embedding 得到 image token
2.将 image token 按 75%的比例进行随机 mask,通过随机生成的张量 noise 进行 argsort 操作的方式来完成对 image patch 的随机 mask。其中,需要注意,该函数中额外传回两个参数 mask 和 ids_restore。mask 记录了 mask patch 在原始图片中的位置,用于后续损失函数的计算。ids_restore 记录了传入 encoder 的 image token 在原始图片中的位置,用于后续再 decoder 前进行 unshuffle 操作。
3.将保留的 image token 输入到 encoder 得到 image embeding
4.将 image embeding 和 mask token 一起进行 unshuffle 操作,再加上 positional embedding 后,输入到 decoder 中
5.将输出的 vector 与归一化后的 image patch 计算 mse loss,并反向传播更新梯度。在计算 loss 时,有两个需要注意的点。1、首先,需要对作为 target 的图像 patch 做归一化。2、在计算损失函数时,只对 mask patch 的部分计算损失函数。
精度复现
参考https://github.com/facebookresearch/mae,我们在单机八卡 V100 的配置下,对 ViT-base 和 ViT-large 的在 ImageNet1K 上 fintune 的精度进行了复现。结果如下表所示。
下面分享一下在复现过程中遇到的一些问题和调参,如有问题请指出。
在 fintune 时,MAE 的实现使用了 mixup+cutmix 的数据增广方式,若仅使用 mixup 精度会下降。
在 fintune 时,MAE 中使用了所有 token 特征求平均的方式作为分类 head 的输入,而 cls token 作为输入时精度会有下降。
在预训练过程中,确保使用了足够大的 weight_decay(如官方设为 0.05),否则在下游任务 fintune 时,很容易出现梯度爆炸的问题。而在下游分类任务 fintune 时,设置一个较小的 weight,精度会有一些提升。(PS 在复现 vit-l 时,在 pretrain 时设置 weight_decay 0.01,在 fintune 时会出现梯度爆炸)
下表展示了 vit-b 模型的复现过程上述过程的精度提升
我们在开源框架 EasyCV 中复现了 MAE 算法。详细参数配置和实验日志参考 github 上的自监督 modelzoo(https://github.com/alibaba/EasyCV/blob/master/docs/source/model_zoo_ssl.md)。
Tutorial
接下来,我们将通过一个实际的例子介绍如何基于 EasyCV 进行 MAE 算法的预训练和微调,也可以在该链接查看详细步骤。
一、安装依赖包
如果是在本地开发环境运行,可以参考该链接安装环境。若使用 PAI-DSW 进行实验则无需安装相关依赖,在 PAI-DSW docker 中已内置相关环境。
二、数据准备
自监督训练只需要提供无标注图片即可进行, 你可以下载 ImageNet 数据,或者使用你自己的图片数据。需要提供一个包含若干图片的文件夹路径 p,以及一个文件列表,文件列表中是每个图片相对图片目录 p 的路径。
图片文件夹结构示例如下, 文件夹路径为./images
文件列表内容如下:
为了快速走通流程,我们也提供了一个小的示例数据集,执行如下命令下载解压:
三、模型预训练
以 vit-base 为示例。在 EasyCV 中,使用配置文件的形式来实现对模型参数、数据输入及增广方式、训练策略的配置,仅通过修改配置文件中的参数设置,就可以完成实验配置进行训练。可以直接下载示例配置文件。
查看 easycv 安装位置
执行训练命令
四、模型微调
1、对上一步得到的预训练模型的字段进行修改,以便用于 fintune 任务。
2、下载分类任务示例配置文件
3、执行训练命令
END
后续 EasyCV 会就 SOTA 论文复现进行系列的工作介绍,欢迎大家关注和使用,欢迎大家各种维度的反馈和改进建议以及技术讨论,同时我们十分欢迎和期待对开源社区建设感兴趣的同行一起参与共建。
评论 (1 条评论)