写点什么

结合 RNN 与 Transformer 双重优点,深度解析大语言模型 RWKV

  • 2024-07-05
    广东
  • 本文字数:6901 字

    阅读完需:约 23 分钟

结合RNN与Transformer双重优点,深度解析大语言模型RWKV

本文分享自华为云社区《【云驻共创】昇思MindSpore技术公开课 RWKV 模型架构深度解析》,作者:Freedom123。

一、前言


Transformer 模型作为一种革命性的神经网络架构,于 2017 年由 Vaswani 等人 提出,并在诸多任务中取得了显著的成功。Transformer 的核心思想是自注意力机制,通过全局建模和并行计算,极大地提高了模型对长距离依赖关系的建模能力。但是 Transformer 在处理长序列时面临内存和计算复杂度的问题,因为其复杂度与序列长度呈二次关系一直为业内人员所诟病。今天我们学习的 RWKV,它作为对 Transformers 模型的替代,正在引起越来越多的开发人员的关注。RWKV 模型以简单、高效、可解释性强等特点,成为自然语言处理领域的新宠。下面让我们一起来学习 RWKV 模型。

二、RWKV 简介


RWKV(Receptance Weighted Key Value)是一个结合了 RNN 与 Transformer 双重优点的模型架构,由香港大学物理系毕业的彭博首次提出。其名称源于其 Time-mix 和 Channel-mix 层中使用的四个主要模型元素:R(Receptance):用于接收以往信息;W(Weight):是位置权重衰减向量,是可训练的模型参数; K(Key):是类似于传统注意力中 K 的向量; V(Value):是类似于传统注意力中 V 的向量。


RWKV 模型作为一种革新性的大型语言模型,结合了 RNN 的线性复杂度和 Transformer 的并行处理优势,引入了 Token shift 和 Channel Mix 机制来优化位置编码和多头注意力机制,解决了传统 Transformer 模型在处理长序列时的计算复杂度问题。RWKV 在多语言处理、小说写作、长期记忆保持等方面表现出色,可以主要应用于自然语言处理任务,例如文本分类、命名实体识别、情感分析等。


三、 RWKV 模型的演进


RWKV 模型之所以发展到今天的结构经历了五个阶段,从 RNN 结构到 LSTM 结构,到 GRU 结构,到 GNMT 结构,到 Transformers 结构,最后到 RMKV 结构,下面我们一一来学习每种模型结构,务必做到对模型结构都有一个清晰的认识。


1.RNN 结构


RNN(Recurrent Neural Network)是循环神经网络的缩写,是一种深度学习模型,特别适用于处理序列数据。RNN 具有记忆功能,可以在处理序列数据时保留之前的信息,并将其应用于当前的计算中。RNN 的特点在于其具有循环连接的结构,使得信息可以在网络中传递并被持续更新。RNN 由一个个时间步组成,每个时间步的输入不仅包括当前时刻的输入数据,还包括上一个时间步的隐藏状态,这样就可以在处理序列数据时考虑到上下文信息,这种结构使得 RNN 能够处理不定长的序列数据,如自然语言文本、时间序列数据等。RNN 结构如下图所示:



左边是 RNN 网络,右边是 RNN 网络按时序展开的形式,为什么要按照时序展开?主要是 RNN 中 隐状态更新需要依赖上一次的隐状态信息,就是我们理解的记忆信息。RNN 的基本结构包括一个隐藏层,其中的神经元通过时间步骤连接,允许信息从一个时间步骤传递到下一个时间步骤。RNN 在每个时间步骤上接收一个输入并输出一个隐藏状态。这个隐藏状态包含了网络在当前时间步骤所看到的序列的信息。这个隐藏状态可以被用作下一个时间步骤的输入。对于一个时间步骤 t,RNN 的隐藏状态的计算如下:



尽管 RNN 具有处理序列数据的能力,但它们在处理长序列时会面临梯度消失或梯度爆炸的问题。这是因为通过时间反向传播时,梯度可能会迅速缩小或增大,导致模型难以学习长期依赖关系。为了解决梯度消失问题,出现了一些改进的 RNN 变体,如长短时记忆网络(LSTM)和门控循环单元(Gated Recurrent Unit,GRU)。这些模型通过引入门控机制,允许网络选择性地记住和遗忘信息,从而更有效地处理长序列。

2.LSTM 结构


LSTM 全称 Long Short Term Memory networks,是普通 RNN 的变体,可以有效解决长期依赖的问题。LSTM 的核心是元胞(cell)状态,输入的信息从上方的水平线经过元胞,期间只与其他实线进行少量交互,表示一些线性变换,这使得输入的信息能够较完整的保存下来,也就是说可以保留长期记忆。而 LSTM 对信息进行选择性的保留,是通过门控机制进行实现的。门结构可以控制通过元胞信息的多少。它实际上是对输入信息进行线性变换后,再通过一个 sigmoid 层来实现的,最终将输入最终转为一个系数向量,值的范围在 0~1,可以理解为保留的信息的占比。如果值为 0,则表示将对应的信息全部丢弃,如果为 1 则表示将对应的信息全部保留。LSTM 共有三种门结构,分别是遗忘门、输入门、输出门,LSTM 结构如下图所示:



3.GRU 结构


GRU (Gated Recurrent Unit)是一种用于循环神经网络(RNN)的门控机制,旨在解决长期依赖问题并缓解梯度消失或爆炸现象。GRU 的结构比 LSTM (Long Short-Term Memory)更简单,它包含两个门:更新门(update gate)和重置门(reset gate)。更新门负责控制前一时刻的状态信息对当前时刻状态的影响,其值越大,表明引入的前一时刻状态信息越多。重置门则控制忽略前一时刻状态信息的程度,其值越小,表明忽略得越多。GRU 结构如下图所示:




最后一个步骤是更新记忆阶段,此阶段同时进遗忘和记忆两个步骤,使用同一个门控同时进行遗忘和选择记忆(LSTM 是多个门控制) 。

4.GNMT 结构


NMT 是神经网络翻译系统,通常会含用两个 RNN,一个用来接受输入文本,另一个用来产生目标语句,但是这样的神经网络系统有三个弱点:1.训练速度很慢并且需要巨大的计算资源,由于数量众多的参数,其翻译速度也远低于传统的基于短语的翻译系统(PBMT);2.对罕见词的处理很无力,而直接复制原词在很多情况下肯定不是一个好的解决方法;3.在处理长句子的时候会有漏翻的现象。而 GNMT 中,RNN 使用的是 8 层(实际上 Encoder 是 9 层,输入层是双向 LSTM。)含有残差连接的神经网络,残差连接可以帮助某些信息,比如梯度、位置信息等的传递。同时,attention 层与 decoder 的底层以及 encoder 的顶层相连接,如下图所示:



GNMT encoder 将输入语句变成一系列的向量,每个向量代表原语句的一个词,decoder 会使用这些向量以及其自身已经生成的词,生成下一个词。encoder 和 decoder 通过 attention network 连接,这使得 decoder 可以在产生目标词时关注原语句的不同部分。上面提到,多层堆叠的 LSTM 网络通常会比层数少的网络有更好的性能,然而,简单的错层堆叠会造成训练的缓慢以及容易受到剃度爆炸或梯度消失的影响,在实验中,简单堆叠在 4 层工作良好,6 层简单堆叠性能还好的网络很少见,8 层的就更罕见了,为了解决这个问题,在模型中引入了残差连接,如图:



一句话的译文所需要的关键词可能在出现在原文的任何位置,而且原文中的信息可能是从右往左的,也可能分散并且分离在原文的不同位置,因为为了获得原文更多更全面的信息,双向 RNN 可能是个很好的选择,在本文的模型结构中,只在 Encoder 的第一层使用了双向 RNN,其余的层仍然是单向 RNN,粉色的 LSTM 从左往右的处理句子,绿色的 LSTM 从右往左,二者的输出先是连接,然后再传给下一层的 LSTM,如下图 Bi-directions RNN 示意图:


5.Transformers 结构


Transformer 模型是一种基于自注意力机制的神经网络模型,旨在处理序列数据,特别是在自然语言处理领域得到了广泛应用。Transformer 模型的核心是自注意力机制(Self-Attention Mechanism),它允许模型关注序列中每个元素之间的关系。这种机制通过计算注意力权重来为序列中的每个位置分配权重,然后将加权的位置向量作为输出。模型结构上,Transformer 由一个编码器堆栈和一个解码器堆栈组成,它们都由多个编码器和解码器组成。编码器主要由多头自注意力(Multi-Head Self-Attention)和前馈神经网络组成,而解码器在此基础上加入了编码器-解码器注意力模块。Transformer 结构如下所示:



基于 Transformer 结构的编码器和解码器结构上图所示,左侧和右侧分别对应着编码器(Encoder)和解码器(Decoder)结构。它们均由若干个基本的 Transformer 块(Block)组成(对应着图中的灰色框)。这里 N× 表示进行了 N 次堆叠。每个 Transformer 块都接收一个向量序列。主要涉及到如下几个模块:


1)嵌入表示层:对于输入文本序列,首先通过输入嵌入层(Input Embedding)将每个单词转换为其相对应的向量表示。通常直接对每个单词创建一个向量表示。由于 Transfomer 模型不再使用基于循环的方式建模文本输入,序列中不再有任何信息能够提示模型单词之间的相对位置关系。在送入编码器端建模其上下文语义之前,一个非常重要的操作是在词嵌入中加入位置编码(Positional Encoding)这一特征。具体来说,序列中每一个单词所在的位置都对应一个向量。这一向量会与单词表示对应相加并送入到后续模块中做进一步处理。在训练的过程当中,模型会自动地学习到如何利用这部分位置信息。


2)注意力层:自注意力(Self-Attention)操作是基于 Transformer 的机器翻译模型的基本操作,在源语言的编码和目标语言的生成中频繁地被使用以建模源语言、目标语言任意两个单词之间的依赖关系。给定由单词语义嵌入及其位置编码叠加得到的输入表示{xi ∈ Rd}ti=1,为了实现对上下文语义依赖的建模,进一步引入在自注意力机制中涉及到的三个元素:查询 qi(Query),键 ki(Key),值 vi(Value)。在编码输入序列中每一个单词的表示的过程中,这三个元素用于计算上下文单词所对应的权重得分。直观地说,这些权重反映了在编码当前单词的表示时,对于上下文不同部分所需要的关注程度。


3)前馈层:前馈层接受自注意力子层的输出作为输入,并通过一个带有 Relu 激活函数的两层全连接网络对输入进行更加复杂的非线性变换。实验证明,这一非线性变换会对模型最终的性能产生十分重要的影响。



其中 W1, b1,W2, b2 表示前馈子层的参数。实验结果表明,增大前馈子层隐状态的维度有利于提升最终翻译结果的质量,因此,前馈子层隐状态的维度一般比自注意力子层要大。


4) 残差连接与层归一化:由 Transformer 结构组成的网络结构通常都是非常庞大。编码器和解码器均由很多层基本的 Transformer 块组成,每一层当中都包含复杂的非线性映射,这就导致模型的训练比较困难。因此,研究者们在 Transformer 块中进一步引入了残差连接与层归一化技术以进一步提升训练的稳定性。具体来说,残差连接主要是指使用一条直连通道直接将对应子层的输入连接到输出上去,从而避免由于网络过深在优化过程中潜在的梯度消失问题。


Transformer 模型由于其处理局部和长程依赖关系的能力以及可并行化训练的特点而成为一个强大的替代方案,如 GPT-3、ChatGPT、GPT-4、LLaMA 和 Chinchilla 等都展示了这种架构的能力,推动了自然语言处理领域的前沿。尽管取得了这些重大进展,Transformer 中固有的自注意力机制带来了独特的挑战,主要是由于其二次复杂度造成的。这种复杂性使得该架构在涉及长输入序列或资源受限情况下计算成本高昂且占用内存。这也促使了大量研究的发布,旨在改善 Transformer 的扩展性,但往往以牺牲一些特性为代价。正是在此背景之下,一个由 27 所大学、研究机构组成的开源研究团队,联合发表论文《 RWKV: Reinventing RNNs for the Transformer Era 》,文中介绍了一种新型模型:RWKV(Receptance Weighted Key Value),这是一种新颖的架构,有效地结合了 RNN 和 Transformer 的优点,同时规避了两者的缺点。RWKV 能够缓解 Transformer 所带来的内存瓶颈和二次方扩展问题,实现更有效的线性扩展,同时保留了使 Transformer 在这个领域占主导的一些性质。

四、 RWKV 模型


RWKV 是一个结合了 RNN 与 Transformer 双重优点的模型架构,是一个 RNN 架构的模型,但是可以像 transformer 一样高效训练。RWKV 模型通过 Time-mix 和 Channel-mix 层的组合,以及 distance encoding 的使用,实现了更高效的 Transformer 结构,并且增强了模型的表达能力和泛化能力。Time-mix 层与 AFT(Attention Free Transformer)层相似,采用了一种注意力归一化的方法,以消除传统 Transformer 模型中存在的计算浪费问题。Channel-mix 层则与 GeLU(Gated Linear Unit)层相似,使用了一个 gating mechanism 来控制每条通道的输入和输出。另外,RWKV 模型采用了类似于 AliBi 编码的位置编码方式,将每个位置的信息添加到模型的输入中,以增强模型的时序信息处理能力。这种位置编码方式称为 distance encoding,它考虑了不同位置之间的距离衰减特性,RWKV 结构如下图所示:



这里我们以下图的自回归例子学习 RWKV 的推理过程,用 (x,y ) 表示样本数据和样本标签,图中有 3 对数据: (my,name) , (my name ,is ) 和 (my name is , Bob) . 另外,在语言模型中,标记偏移(token shift)是一种常见的技术,用于训练模型以预测给定上下文中下一个标记(单词、字符或子词单元)的任务。下图中的标记偏移技术是向右移动一个位置,生成的三个 token-shift 为:"0 my","my name","name is"。为什么要进行标记偏移呢? 这是因为这样做具有递归嵌套的思想,比如:"name"向量与"my"向量有关,而"is"向量与"name"向量有关,所以"is"向量自然与"name"向量有关。好处是:给融入循环神经网络思想带来了便利的同时还保持了并行性。具体流程下面的 Time-Mix 模块和 Channel-Mix 模块会详细介绍。如下图所示,这两个模块是 RWKV 架构的主要模块。Time-Mix 模块可以看成根据隐状态(State)生成候选预测向量,Channel-Mix 模块则可以看成生成最终的预测向量。


1.Time Mixing 模块




2.Channel Mixing 模块


3.RWKV 的优势


1)高效训练和推理:RWKV 模型既可以像传统 Transformer 模型一样高效训练,也具有类似于 RNN 的推理能力。这使得 RWKV 模型可以支持串行模式和高效推理,也可以支持并行模式(并行推理训练)和长程记忆。


2)支持高效训练:RWKV 模型使用了 Time-mix 和 Channel-mix 层,以消除传统 Transformer 模型中存在的计算浪费问题。这使得 RWKV 模型在训练过程中具有更高的效率和更快的速度。


3)支持大规模自然语言处理任务:RWKV 模型可以处理大规模的自然语言处理任务,如文本分类、命名实体识别、情感分析等。


4)可扩展性强:RWKV 模型具有良好的可扩展性,可以方便地进行模型扩展和改进,以适应不同任务的需求。

4.RWKV 模型参数


目前官方已经就 RWKV 开源了多个模型。主要是 Raven 系列模型,Raven 是基于 RWKV-4 架构在 Pile 数据集上训练和微调的大模型,做过指令微调或者 chat 微调版本。此外,也包括了非 Raven 版本的 RWKV-4 的模型。


五、 RWKV 模型代码阅读

1.RWKV 模型推理代码



代码解释:


1~2 行:引入代码需要的库


4 行:对输出进行校验


6~9 行:加载 RWKV/rwkv-4-169m-pile 模型,并且输入提示词


11~12 行:运行模型,解码生成内容


13 行:期望输出与真实输出内容进行校验

2.Channel Mixing 模块代码:



x 通道混合层接受与此标记对应的输入,以及 x 与前一个标记对应的输入,我们称之为 last_x。last_x 存储在这个 RWKV 层的 state. 其余输入是学习 RWKV 的 parameters。首先,我们使用学习的权重对 x 和进行线性插值 last_x。我们将此插值 x 作为输入运行到具有平方 relu 激活的 2 层前馈网络,最后与另一个前馈网络的 sigmoid 激活相乘(在经典 RNN 术语中,这称为门控)。请注意,就内存使用而言,矩阵 Wk,Wr,Wv 包含几乎所有参数(1024×1024 matrices 它们是矩阵,而其他变量只是 1024 维向量)。矩阵乘法(@在 python 中)贡献了绝大多数所需的计算。

3.Time mixing 模块代码:


时间混合的开始类似于通道混合,通过将此标记的插入 x 到最后一个标记的 x。然后我们应用学到的矩阵以获得“key”, “value” and “receptance”向量。


六、与其他模型的比较

1.复杂度对比


从和 Transformer,Reformer,Performer,Linear Transformers,AFT-full,AFT-local,MEGA 等模型的复杂度比较中可以看的出来,RWKV 模型的时间复杂度和空间负责度都是最低的,费别为 O(Td)和 O(d),其中 T 表示序列长度,d 表示特征维度,c 表示 MEGA 的二次注意力块大小。


2.精度对比


RWKV 似乎可以像 SOTA transformer 一样缩放。至少多达 140 亿个参数。在同等规模参数中,RWKV-4 系列与 Pythia 和 GPT-J 比都是很有优势的,对比如下图所示:


3.推理速度和内存占用


RWKV 网络与不同类型的 Transformer 性能的实验结果对比如下图所示。RWKV 时间消耗随序列长度是线性增加,且时间消耗远小于各种类型的 Transformer。



RWKV 与 Transformer 预训练模型(BLOOM、OPT、Pythia)效果对比测试如下图所示。在六个基准测试中(Winogrande、PIQA、ARC-C、ARC-E、LAMBADA 和 SciQ),RWKV 与开源二次复杂度 transformer 模型 Pythia、OPT 和 BLOOM 具有相当的竞争力。RWKV 甚至在四个任务(PIQA、OBQA、ARC-E 和 COPA)中胜过了 Pythia 和 GPT-Neo。



下图显示,增加上下文长度会导致 Pile 上的测试损失降低,这表明 RWKV 能够有效利用较长的上下文信息。


七、小结


本节我们学习了 RWKV 模型,我们掌握了 RWKV 模型结构的整个演进过程,从最初的 RNN 结构,到 LSTM 结构,到 GRU 结构,到 GNTM 模型,到 Transformers 模型,最后到 RWKV 模型,我们学习了每种模型结构出现的原因,以及其对应的优势和不足。接下来,我们学习了 RWKV 模型,Time Mixing 模块和 Channel Mixing 模块。我们通过学习 RWKV 模型的 python 代码,对 RWKV 模型从复杂度,精度,推理速度,内存占用等四个维度和其他模型进行了对比。


通过本节学习,我们对 RWKV 模型有了一个全面的认识,RWKV 模型正在作为一颗在大模型领域的新星正在受到越来越多社区开发者的关注,希望 RWKV 模型在接下来的版本迭代过程中能给大家带来更多的惊喜。


点击关注,第一时间了解华为云新鲜技术

发布于: 29 分钟前阅读数: 5
用户头像

提供全面深入的云计算技术干货 2020-07-14 加入

生于云,长于云,让开发者成为决定性力量

评论

发布
暂无评论
结合RNN与Transformer双重优点,深度解析大语言模型RWKV_深度学习_华为云开发者联盟_InfoQ写作社区