深度学习应用篇 - 元学习 [14]:基于优化的元学习 -MAML 模型、LEO 模型、Reptile 模型
深度学习应用篇-元学习[14]:基于优化的元学习-MAML 模型、LEO 模型、Reptile 模型
1.Model-Agnostic Meta-Learning
Model-Agnostic Meta-Learning (MAML): 与模型无关的元学习,可兼容于任何一种采用梯度下降算法的模型。MAML 通过少量的数据寻找一个合适的初始值范围,从而改变梯度下降的方向,找到对任务更加敏感的初始参数,使得模型能够在有限的数据集上快速拟合,并获得一个不错的效果。该方法可以用于回归、分类以及强化学习。
该模型的 Paddle 实现请参考链接:PaddleRec版本
1.1 MAML
MAML 是典型的双层优化结构,其内层和外层的优化方式如下:
1.1.1 MAML 内层优化方式
内层优化涉及到基学习器,从任务分布 中随机采样第 个任务 。任务 上,基学习器的目标函数是:
其中, 是基学习器, 是基学习器参数, 是基学习器在 上的损失。更新基学习器参数:
其中, 是元学习器提供给基学习器的参数初始值 ,在任务 上更新 后 .
1.1.2 MAML 外层优化方式
外层优化涉及到元学习器,将 反馈给元学匀器,此时元目标函数是:
元目标函数是所有任务上验证集损失和。更新元学习器参数:
1.2 MAML 算法流程
randomly initialize
while not done do:
sample batch of tasks
for all do:
evaluate with respect to K examples
compute adapted parameters with gradient descent: $\theta_{i}^{N}=\theta_{i}^{N-1} -\alpha\left[\nabla_{\phi}L_{T_{i}}\left(f_{\phi}\right)\right]{\phi=\theta{i}^{N-1}} $
end for
update $\theta \leftarrow \theta-\beta \sum_{T_{i} \sim p(T)} \nabla_{\theta}\left[L_{T_{i}}\left(f_{\phi}\right)\right]{\phi=\theta{i}^{N}} $
end while
MAML 中执行了两次梯度下降 (gradient by gradient),分别作用在基学习器和元学习器上。图 1 给出了 MAML 中特定任务参数 和元级参数 的更新过程。
图 1 MAML 示意图。灰色线表示特定任务所产生的梯度值(方向);黑色线表示元级参数选择更新的方向(黑色线方向是几个特定任务产生方向的平均值);虚线代表快速适应,不同的方向代表不同任务更新的方向。
1.3 MAML 模型结构
MAML 是一种与模型无关的元学习方法,可以适用于任何基于梯度优化的模型结构。
基准模型:4 modules with a 3 3 convolutions and 64 filters, followed by batch normalization, a ReLU nonlinearity, and 2 2 max-pooling。
1.4 MAML 分类结果
1.5 MAML 的优缺点
优点
适用于任何基于梯度优化的模型结构。
双层优化结构,提升模型精度和泛化能力,避免过拟合。
缺点
存在二阶导数计算
1.6 对 MAML 的探讨
每个任务上的基学习器必须是一样的,对于差别很大的任务,最切合任务的基学习器可能会变化,那么就不能用 MAML 来解决这类问题。
MAML 适用于所有基于随机梯度算法求解的基学习器,这意味着参数都是连续的,无法考虑离散的参数。对于差别较大的任务,往往需要更新网络结构。使用 MAML 无法完成这样的结构更新。
MAML 使用的损失函数都是可求导的,这样才能使用随机梯度算法来快速优化求解,损失函数中不能有不可求导的奇异点,否则会导致优化求解不稳定。
MAML 中考虑的新任务都是相似的任务,所以没有对任务进行分类,也没有计算任务之间的距离度量。对每一类任务单独更新其参数初始值,每一类任务的参数初始值不同,这些在 MAML 中都没有考虑。
参考文献
[1] Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks.
2.Latent Embedding Optimization
Latent Embedding Optimization (LEO) 学习模型参数的低维潜在嵌入,并在这个低维潜在空间中执行基于优化的元学习,将基于梯度的自适应过程与模型参数的基础高维空间分离。
2.1 LEO
在元学习器中,使用 SGD 最小化任务验证集损失函数,使得模型的泛化能力最大化,计算元参数,元学习器将元参数输入基础学习器,继而,基础学习器最小化任务训练集损失函数,快速给出任务上的预测结果。LEO 结构如图 1 所示。
图 1 LEO 结构图。 是任务 的 support set, 是任务 的 query set, 是通过编码器计算的 个类别的类别特征, 是基学习器, 是基学习器参数, , 。
LEO 包括基础学习器和元学习器,还包括编码器和解码器。在基础学习器中,编码器将高维输入数据映射成特征向量,解码器将输入数据的特征向量映射成输入数据属于各个类别的概率值,基础学习器使用元学习器提供的元参数进行参数更新,给出数据标注的预测结果。元学习器为基础学习器的编码器和解码器提供元参数,元参数包括特征提取模型的参数、编码器的参数、解码器的参数等,通过最小化所有任务上的泛化误差,更新元参数。
2.2 基础学习器
编码器和解码器都在基础学习器中,用于计算输入数据属于每个类别的概率值,进而对输入数据进行分类。元学习器提供编码器和解码器中的参数,基础学习器快速的使用编码器和解码器计算输入数据的分类。任务训练完成后,基础学习器将每个类别数据的特征向量和任务 的基础学习器参数 输入元学习器,元学习器使用这些信息更新元参数。
2.2.1 编码器
编码器模型包括两个主要部分:编码器和关系网络。
编码器 ,其中 是编码器的可训练参数,其功能是将第 个类别的输入数据映射成第 个类别的特征向量。
关系网络 ,其中 是关系网络的可训练参数,其功能是计算特征之间的距离。
第 个类别的输入数据的特征记为 。对于输入数据,首先,使用编码器 对属于第 个类别的输入数据进行特征提取;然后,使用关系网络 计算特征之间的距离,综合考虑训练集中所有样本点之间的距离,计算这些距离的平均值和离散程度;第 个类别输入数据的特征 服从高斯分布,且高斯分布的期望是这些距离的平均值,高斯分布的方差是这些距离的离散程度,具体的计算公式如下:
其中, 是类别总数, 是每个类别的图片总数, ${D}{n}^{\mathrm{tr}}nKKNz=\left(z{1}, \cdots, z_{N}\right)$。
2.2.2 解码器
解码器 ,其中 是解码器的可训练参数,其功能是将每个类别输入数据的特征向量 映射成属于每个类别的概率值 :
其中,任务 的基础学习器参数记为 ,基础学习器参数由属于每个类别的概率值组成,记为 ,基础学习器参数 $\boldsymbol{w}{n}ng{\phi_{d}}$ 是从特征向量到基础学习器参数的映射。
图 2 LEO 基础学习器工作原理图。
2.2.3 基础学习器更新过程
在基础学习器中,任务 的交叉熵损失函数是:
其中, 是任务 训练集 中的样本点, 是任务 的基础学习器,最小化任务 的损失函数更新任务专属参数 。在解码器模型中,任务专属参数为 ,更新任务专属参数 意味着更新特征向量 :
其中,$\boldsymbol{z}{n}^{\prime}\boldsymbol{\theta}{\varepsilon}^{\prime}\theta_{\varepsilon}^{\prime}\varepsilon\mathrm{D}{\varepsilon}^{\mathrm{val}}L{\varepsilon}^{\mathrm{val}}\left(f_{\theta_{\varepsilon}^{\prime}}\right)z_{n}^{\prime}\theta_{\varepsilon}^{\prime}$ 输入元学习器,在元学习器中更新元参数。
2.3 元学习器更新过程
在元学习器中,最小化所有任务 的验证集的损失函数的求和,最小化任务上的模型泛化误差:
其中, 是任务 验证集的损失函数,衡量了基础学习器模型的泛化误差,损失函数越小,模型的泛化能力越好。 是高斯分布,$D_{\mathrm{KL}}\left{q\left(z_{n} \mid {D}{n}^{\mathrm{tr}}\right) | p\left(z{n}\right)\right}q\left(z_{n} \mid D_{n}^{\text {tr }}\right)p\left(z_{n}\right)\mathrm{KL}q\left(z_{n} \mid {D}{n}^{\text {tr}}\right)\left|s\left(z{n}^{\prime}\right)-z_{n}\right|z_{n}z_{n}^{\prime}RR$ 的计算公式如下:
其中, {C}{d}\phi_{d}\lambda_{1},\lambda_{2}>0\left|C_{d}-\mathbb{I}\right|{2}C{d}\phi_{d}$ 的行和行之间的相关性不能太大, 每个类别的特征向量之间的相关性不能太大, 属于每个类别的概率值之间的相关性也不能太大,分类要尽量准确。
2.4 LEO 算法流程
LEO 算法流程
randomly initialize
let \phi=\left{\phi_{e}, \phi_{r}, \phi_{d}, \alpha\right}
while not converged do:
for number of tasks in batch do:
sample task instance
let
encode to z using and
decode to initial params using
initialize
for number of adaptation steps do:
compute training loss $\mathcal{L}{\mathcal{T}{i}}^{t r}\left(f_{\theta_{i}^{\prime}}\right)$
perform gradient step w.r.t. :
$\mathbf{z}^{\prime} \leftarrow \mathbf{z}^{\prime}-\alpha \nabla_{\mathbf{z}^{\prime}} \mathcal{L}{\mathcal{T}{i}}^{t r}\left(f_{\theta_{i}^{\prime}}\right)$
decode to obtain using
end for
compute validation loss $\mathcal{L}{\mathcal{T}{i}}^{v a l}\left(f_{\theta_{i}^{\prime}}\right)$
end for
perform gradient step w.r.t :$\phi \leftarrow \phi-\eta \nabla_{\phi} \sum_{\mathcal{T}{i}} \mathcal{L}{\mathcal{T}{i}}^{v a l}\left(f{\theta_{i}^{\prime}}\right)$
end while
(1) 初始化元参数:编码器参数 、关系网络参数 、解码器参数 , 在元学习器中更新的元参数包括 。
(2) 使用片段式训练模式,随机抽取任务 , ${D}{\varepsilon}^{\mathrm{tr}}\varepsilon{D}{\varepsilon}^{\mathrm{val}}\varepsilon$ 的验证集。
(3) 使用编码器 和关系网络 将任务 的训练集 编码成特征向量 ,使用 解码器 从特征向量映射到任务 的基础学习器参数 ${\theta}{\varepsilon}\varepsilonL{\varepsilon}^{\mathrm{tr}}\left(f_{\theta_{\varepsilon}}\right)\varepsilon$ 的损失函数,更新每个类别的特征向量:
使用解码器 从更新后的特征向量映射到更新后的任务 的基础学习器参数 ${\theta}{\varepsilon}^{\prime}\varepsilonL{\varepsilon}^{\text {val}}\left(f_{\theta_{s}^{\prime}}\right)$;基础学习器将更新后的参数和验证集损失函数值输入元学习器。
(4) 更新元参数, ,最小化所有任务 的验证集的损失和,将更新后的元参数输人基础学习器,继续处理新的分类任务。
2.5 LEO 模型结构
LEO 是一种与模型无关的元学习,[1] 中给出的各部分模型结构及参数如表 1 所示。
表 1 LEO 各部分模型结构及参数。
2.6 LEO 分类结果
2.7 LEO 的优点
新任务的初始参数以训练数据为条件,这使得任务特定的适应起点成为可能。通过将关系网络结合到编码器中,该初始化可以更好地考虑所有输入数据之间的联合关系。
通过在低维潜在空间中进行优化,该方法可以更有效地适应模型的行为。此外,通过允许该过程是随机的,可以表达在少数数据状态中存在的不确定性和模糊性。
3.Reptile
Reptil 是 MAML 的特例、近似和简化,主要解决 MAML 元学习器中出现的高阶导数问题。因此,Reptil 同样学习网络参数的初始值,并且适用于任何基于梯度的模型结构。
在 MAML 的元学习器中,使用了求导数的算式来更新参数初始值,导致在计算中出现了任务损失函数的二阶导数。在 Reptile 的元学习器中,参数初始值更新时,直接使用了任务上的参数估计值和参数初始值之间的差,来近似损失函数对参数初始值的导数,进行参数初始值的更新,从而不会出现任务损失函数的二阶导数。
Peptile 有两个版本:Serial Version 和 Batched Version,两者的差异如下:
3.1 Serial Version Reptile
单次更新的 Reptile,每次训练完一个任务的基学习器,就更新一次元学习器中的参数初始值。
(1) 任务上的基学习器记为 ,其中 是基学习器中可训练的参数, 是元学习器提供给基学习器的参数初始值。在任务 上,基学习器的损失函数是 ,基学习器中的参数经过 次迭代更新得到参数估计值:
(2) 更新元学习器中的参数初始值:
Serial Version Reptile 算法流程
initialize , the vector of initial parameters
for iteration=1, 2, ... do:
sample task , corresponding to loss on weight vectors
compute
update
end for
3.2 Batched Version Reptile
批次更新的 Reptile,每次训练完多个任务的基学习器之后,才更新一次元学习器中的参数初始值。
(1) 在多个任务上训练基学习器,每个任务从参数初始值开始,迭代更新 次,得到参数估计值。
(2) 更新元学习器中的参数初始值:
其中, 是指每次训练完 个任务上的基础学习器后,才更新一次元学习器中的参数初始值。
Batched Version Reptile 算法流程
initialize
for iteration=1, 2, ... do:
sample tasks , , ... , ,
for i=1, 2, ... , n do:
compute
end for
update
end for
3.3 Reptile 分类结果
表 1 Reptile 在 Omniglot 上的分类结果。
表 1 Reptile 在 miniImageNet 上的分类结果。
更多优质内容请关注公重号:汀丶人工智能
版权声明: 本文为 InfoQ 作者【汀丶人工智能】的原创文章。
原文链接:【http://xie.infoq.cn/article/308b078948c6b6c27e10311da】。
本文遵守【CC-BY 4.0】协议,转载请保留原文出处及本版权声明。
评论