写点什么

深度学习基础 5: 交叉熵损失函数、MSE、CTC 损失适用于字识别语音等序列问题、Balanced L1 Loss 适用于目标检测

作者:汀丶
  • 2023-04-18
    浙江
  • 本文字数:3629 字

    阅读完需:约 12 分钟

深度学习基础5:交叉熵损失函数、MSE、CTC损失适用于字识别语音等序列问题、Balanced L1 Loss适用于目标检测

1.交叉熵损失函数

在物理学中,“熵”被用来表示热力学系统所呈现的无序程度。香农将这一概念引入信息论领域,提出了“信息熵”概念,通过对数函数来测量信息的不确定性。交叉熵(cross entropy)是信息论中的重要概念,主要用来度量两个概率分布间的差异。假定 p 和 q 是数据 x 的两个概率分布,通过 q 来表示 p 的交叉熵可如下计算:


交叉熵刻画了两个概率分布之间的距离,旨在描绘通过概率分布 q 来表达概率分布 p 的困难程度。根据公式不难理解,交叉熵越小,两个概率分布 p 和 q 越接近。


这里仍然以三类分类问题为例,假设数据 x 属于类别 1。记数据 x 的类别分布概率为 y,显然 y=(1,0,0)代表数据 x 的实际类别分布概率。记代表模型预测所得类别分布概率。那么对于数据 x 而言,其实际类别分布概率 y 和模型预测类别分布概率 的交叉熵损失函数定义为:


很显然,一个良好的神经网络要尽量保证对于每一个输入数据,神经网络所预测类别分布概率与实际类别分布概率之间的差距越小越好,即交叉熵越小越好。于是,可将交叉熵作为损失函数来训练神经网络。



图 1 三类分类问题中输入 x 的交叉熵损失示意图(x 属于第一类)



在上面的例子中,假设所预测中间值 (z1,z2,z3)经过 Softmax 映射后所得结果为 (0.34,0.46,0.20)。由于已知输入数据 x 属于第一类,显然这个输出不理想而需要对模型参数进行优化。如果选择交叉熵损失函数来优化模型,则 (z1,z2,z3)这一层的偏导值为 (0.34−1,0.46,0.20)=(−0.66,0.46,0.20)。


可以看出,和交叉熵损失函数相互结合,为偏导计算带来了极大便利。偏导计算使得损失误差从输出端向输入端传递,来对模型参数进行优化。在这里,交叉熵与 Softmax 函数结合在一起,因此也叫 损失(Softmax with cross-entropy loss)。

2.均方差损失(Mean Square Error,MSE)

均方误差损失又称为二次损失、L2 损失,常用于回归预测任务中。均方误差函数通过计算预测值和实际值之间距离(即误差)的平方来衡量模型优劣。即预测值和真实值越接近,两者的均方差就越小。


计算方式:假设有 n 个训练数据 ,每个训练数据 的真实输出为 ,模型对 的预测值为 。该模型在 n 个训练数据下所产生的均方误差损失可定义如下:


假设真实目标值为 100,预测值在-10000 到 10000 之间,我们绘制 MSE 函数曲线如 图 1 所示。可以看到,当预测值越接近 100 时,MSE 损失值越小。MSE 损失的范围为 0 到∞。


3.CTC 损失

3.1 CTC 算法算法背景-----文字识别语音等序列问题

CTC 算法主要用来解决神经网络中标签和预测值无法对齐的情况通常用于文字识别以及语音等序列学习领域。举例来说,在语音识别任务中,我们希望语音片段可以与对应的文本内容一一对应,这样才能方便我们后续的模型训练。但是对齐音频与文本是一件很困难的事,如 图 1 所示,每个人的语速都不同,有人说话快,有人说话慢,我们很难按照时序信息将语音序列切分成一个个的字符片段。而手动对齐音频与字符又是一件非常耗时耗力的任务



图 1 语音识别任务中音频与文本无法对齐


在文本识别领域,由于字符间隔、图像变形等问题,相同的字符也会得到不同的预测结果,所以同样会会遇到标签和预测值无法对齐的情况。如 图 2 所示。



图 2 不同表现形式的相同字符示意图


总结来说,假设我们有个输入(如字幅图片或音频信号)X ,对应的输出是 Y,在序列学习领域,通常会碰到如下难点:


  • X 和 Y 都是变长的;

  • X 和 Y 的长度比也是变化的;

  • X 和 Y 相应的元素之间无法严格对齐。

3.2 算法概述

引入 CTC 主要就是要解决上述问题。这里以文本识别算法 CRNN 为例,分析 CTC 的计算方式及作用。CRNN 中,整体流程如 图 3 所示。



图 3 CRNN 整体流程


CRNN 中,首先使用 CNN 提取图片特征,特征图的维度为,特征图 x 可以定义为:


然后,将特征图的每一列作为一个时间片送入 LSTM 中。令 t 为代表时间维度的值,且满足 ,则每个时间片可以表示为:


经过 LSTM 的计算后,使用 softmax 获取概率矩阵 y,定义为:


经过 LSTM 的计算后,使用 softmax 获取概率矩阵 ,定义为:


n 为字符字典的长度,由于 是概率,所以 。对每一列 求 argmax(),就可以获取每个类别的概率。


考虑到文本区域中字符之间存在间隔,也就是有的位置是没有字符的,所以这里定义分隔符 −来表示当前列的对应位置在图像中没有出现字符。用 代表原始的字符字典,则此时新的字符字典 为:


此时,就回到了我们上文提到的问题上了,由于字符间隔、图像变形等问题,相同的字符可能会得到不同的预测结果。在 CTC 算法中,定义了 B 变换来解决这个问题。 B 变换简单来说就是将模型的预测结果去掉分割符以及重复字符(如果同个字符连续出现,则表示只有 1 个字符,如果中间有分割符,则表示该字符出现多次),使得不同表现形式的相同字符得到统一的结果。如 图 4 所示。



这里举几个简单的例子便于理解,这里令 T 为 10:


对于字符中间有分隔符的重复字符则不进行合并:


当获得 LSTM 输出后,进行 B 变换就可以得到最终结果。由于 B 变换并不是一对一的映射,例如上边的 3 个不同的字符都可以变换为 state,所以在 LSTM 的输入为 x 的前提下,CTC 的输出为 l 的概率应该为:


其中, 为 LSTM 的输出向量, 代表所有能通过 B 变换得到 l 的 的集合。


而对于任意一个 π,又有:


其中, 代表 t 时刻 π为对应值的概率,这里举一个例子进行说明:


$\begin{array}{c}\pi=-s-t-aattte\ y_{\pi_t}^t=y_-^1y_s^2y_-^3y_t^4y_-^5y_a^6y_a^7y_t^8y_t^9*y_e^10\ \end{array}$


不难理解,使用 CTC 进行模型训练,本质上就是希望调整参数,使得 取最大。


具体的参数调整方法,可以阅读以下论文进行了解:Connectionist Temporal Classification: Labelling Unsegmented Sequence Data with Recurrent Neural Networks

4.平衡 L1 损失(Balanced L1 Loss)---目标检测

目标检测(object detection)的损失函数可以看做是一个多任务的损失函数,分为分类损失和检测框回归损失:


表示分类损失函数、表示检测框回归损失函数。在分类损失函数中,p 表示预测值,u 表示真实值。表示类别 u 的位置回归结果,v 是位置回归目标。λ用于调整多任务损失权重。定义损失大于等于 1.0 的样本为 outliers(困难样本,hard samples),剩余样本为 inliers(简单样本,easy sample)。


平衡上述损失的一个常用方法就是调整两个任务损失的权重,然而,回归目标是没有边界的,直接增加检测框回归损失的权重将使得模型对 outliers 更加敏感,这些 hard samples 产生过大的梯度,不利于训练。inliers 相比 outliers 对整体的梯度贡献度较低,相比 hard sample,平均每个 easy sample 对梯度的贡献为 hard sample 的 30%,基于上述分析,提出了 balanced L1 Loss(Lb)。


Balanced L1 Loss 受 Smooth L1 损失的启发,Smooth L1 损失通过设置一个拐点来分类 inliers 与 outliers,并对 outliers 通过一个进行梯度截断。相比 smooth l1 loss,Balanced l1 loss 能显著提升 inliers 点的梯度,进而使这些准确的点能够在训练中扮演更重要的角色。设置一个拐点区分 outliers 和 inliers,对于那些 outliers,将梯度固定为 1,如下图所示:



Balanced L1 Loss 的核心思想是提升关键的回归梯度(来自 inliers 准确样本的梯度),进而平衡包含的样本及任务。从而可以在分类、整体定位及精确定位中实现更平衡的训练,Balanced L1 Loss 的检测框回归损失如下:


其相应的梯度公示如下:


基于上述公式,设计了一种推广的梯度公式为:


其中,控制着 inliers 梯度的提升;一个较小的α会提升 inliers 的梯度同时不影响 outliers 的值。来调整回归误差的上界,能够使得不同任务间更加平衡。α,γ从样本和任务层面控制平衡,通过调整这两个参数,从而达到更加平衡的训练。Balanced L1 Loss 公式如下:


其中参数满足下述条件:


默认参数设置:α = 0.5,γ=1.5


Libra R-CNN: Towards Balanced Learning for Object Detection


发布于: 刚刚阅读数: 4
用户头像

汀丶

关注

本博客将不定期更新关于NLP等领域相关知识 2022-01-06 加入

本博客将不定期更新关于机器学习、强化学习、数据挖掘以及NLP等领域相关知识,以及分享自己学习到的知识技能,感谢大家关注!

评论

发布
暂无评论
深度学习基础5:交叉熵损失函数、MSE、CTC损失适用于字识别语音等序列问题、Balanced L1 Loss适用于目标检测_人工智能_汀丶_InfoQ写作社区