AI 简报 - 模型集成 SAM 和 SWA

1.SAM
1.1 背景
SAM 即 Sharpness-Aware Minimization for Effciently Improving Generalization,是一种提高模型泛化的 loss 策略,该文章基于的思想还是 flat minimum 比 sharp minimun 更具有泛化性。设计了一个 loss 来构建;这种 flat minimum 的效果。
1.2. 解读
如何设计这样的 loss 呢?
文章分析了传统的 SGD 在如交叉熵的 loss,在损失空间中可能存在许许多多的局部和全局的最小值。有的是尖锐的深沟,这种不利用泛化,如下中间的图示

设计的思路
考虑如果是 flat 的 minimum,那临近的损失和最优的损失差距不会很大,相较于 sharp minimum。文章在原来 loss 的权重加入一个 eps 噪声,来代表周围的最优解的情况。然就计算最大差距的最小化
1.3.方法和实现细节
文章提供的 github 实现,不过 loss 部分是统计 jax 方式实现的。https://github.com/google-research/sam

2.SWA
2.1.背景
AveragingWeights Leads to Wider Optima and Better Generalization;该论文是承袭 FGE(Loss Surfaces, Mode Connectivity, and Fast Ensembling of DNN), 提出不是在模型空间上做融合,而是权重空间上做权重平均化的融合。以单模型就可以达到 FGE 多模型融合的效果,提高推理的效率。
2.2. 解读
论文指出权重平均的操作,能够使得损失曲面上有更宽的区域,所以有更好的泛化能力。

2.3.方法和实现细节
2.3.1 方法
采用更剧烈的 cycle learning rate schedule, 甚至是固定的学习率

训练一个初始的权重 w
以初始权重 w 开始训练,以周期为单位进行权重的平均。需要注意的如果有 BN,因为这里没有更新,所以需要对平均权重的模型, 进行 BN 参数的平均, 做一次前向的计算。

2.3.2 细节
CLR 上下限的选择


版权声明: 本文为 InfoQ 作者【AIWeker】的原创文章。
原文链接:【http://xie.infoq.cn/article/0dea613472eb53aa7744c5d52】。文章转载请联系作者。
评论