一文详解 ATK Loss 论文复现与代码实战
本文分享自华为云社区《ATK Loss论文复现与代码实战》,作者:李长安。
损失是一种非常通用的聚合损失,其可以和很多现有的定义在单个样本上的损失 结合起来,如 logistic 损失,hinge 损失,平方损失(L2),绝对值损失(L1)等等。通过引入自由度 k,损失可以更好的拟合数据的不同分布。当数据存在多分布或类别分布不均衡的时候,最小化平均损失会牺牲掉小类样本以达到在整体样本集上的损失最小;当数据存在噪音或外点的时候,最大损失对噪音非常的敏感,学习到的分类边界跟 Bayes 最优边界相差很大;当采取损失最为聚合损失的时候(如 k=10),可以更好的保护小类样本,并且其相对于最大损失而言对噪音更加鲁棒。所以我们可以推测:最优的 k 即不是 k = 1(对应最大损失)也不是 k = n(对应平均损失),而是在[1, n]之间存在一个比较合理的 k 的取值区间。
上图结合仿真数据显示了最小化平均损失和最小化最大损失分别得到的分类结果。可以看出,当数据分布不均衡或是某类数据存在典型分布和非典型分布的时候,最小化平均损失会忽略小类分布的数据而得到次优的结果;而最大损失对样本噪音和外点(outliers)非常的敏感,即使数据中仅存在一个外点也可能导致模型学到非常糟糕的分类边界;相比于最大损失损失,第 k 大损失对噪音更加鲁棒,但其在 k > 1 时非凸非连续,优化非常困难。
由于真实数据集非常复杂,可能存在多分布性、不平衡性以及噪音等等,为了更好的拟合数据的不同分布,我们提出了平均 Top-K 损失作为一种新的聚合损失。
本项目最初的思路来自于八月份参加比赛的时候。由于数据集复杂,所以就在想一些难例挖掘的方法。看看这个方法能否带来一个更好的模型效果。该方法的主要思想是使用数值较大的排在前面的梯度进行反向传播,可以认为是一种在线难例挖掘方法,该方法使模型讲注意力放在较难学习的样本上,以此让模型产生更好的效果。代码如下所示。
topk_loss 的主要思想
topk_loss 的核心思想,即通过控制损失函数的梯度反传,使模型对 Loss 值较大的样本更加关注。该函数即为 CrossEntropyLoss 函数的具体实现,只不过是在计算 nllloss 的时候取了前 70%的梯度,
数学逻辑:挖掘反向传播前 70% 梯度。
代码实战
此部分使用比赛中的数据集,并带领大家使用 Top-k Loss 完成模型训练。在本例中使用前 70%的 Loss。
总结
在该工作中,分析了平均损失和最大损失等聚合损失的优缺点,并提出了平均 Top-K 损失(损失)作为一种新的聚合损失,其包含了平均损失和最大损失并能够更好的拟合不同的数据分布,特别是在多分布数据和不平衡数据中。损失降低正确分类样本带来的损失,使得模型学习的过程中可以更好的专注于解决复杂样本,并由此提供了一种保护小类数据的机制。损失仍然是原始损失的凸函数,具有很好的可优化性质。我们还分析了损失的理论性质,包括 classification calibration 等。
Top-k loss 的参数设置为 1 时,此损失函数将变 cross_entropy 损失,对其进行测试,结果与原始 cross_entropy()完全一样。但是我在实际的使用中,使用此损失函数却没使模型取得一个更好的结果。需要做进一步的实验。
版权声明: 本文为 InfoQ 作者【华为云开发者联盟】的原创文章。
原文链接:【http://xie.infoq.cn/article/fbef0c856bb8a2c472a00e2bc】。文章转载请联系作者。
评论