DeepSeek 模型 MOE 结构代码详解
其实在 DeepSeek-R1 爆火之前,DeepSeek V2 在我们行业就已经妇孺皆知了,它独特的 MOE 结构值得研究一下。这篇文章是基于 @ZOMI酱 的 2 个视频写的,这 2 个视频讲的很好,建议大家都学习一下:《MOE 终于迎来可视化解读!傻瓜都能看懂 MoE 核心原理!》和《使用昇腾 NPU 手撕 MoE 单机版代码!没想到如此简单!》。
这篇文章是把我自己的理解梳理一下,加强自己的理解和记忆。
MOE 结构概述
我们可以从 zomi 酱视频里面的这张图开始:

MOE 是 mixture of experts 的缩写,简单来说,就是把传统 transformer 结构中 decoder 层里面的单个线性层替换层多个并列的线性层。在这些线性层前面还有一个 Router,Router 会选择并列线性层里面的一部分进行计算。这样的话,既能让模型学习更多的知识(多个“专家”),又能减少推理计算量(选择部分“专家”进行计算)。
MOE 计算代码
接下来我们参考 zomi 酱提供的代码来详细看一下 MOE 的计算过程是怎样的:
初始化函数定义
首先,定义了 Expert 类,也就是“专家”,可以看到,专家是由线性层和激活函数构成的简单模型。
然后开始定义 MOE 类。在初始化函数中,定义了这样几个变量:
self.num_experts:专家的数量,也就是上面提到的“并列线性层”的个数,训练后的每个专家的权重都是不同的,代表它们所掌握的“知识”是不同的。
self.top_k:每个输入 token 激活的专家数量。
self.expert_capacity:代表计算每组 token 时,每个专家能被选择的最多次数。
self.gate:路由网络,一般是一个线性层,用来计算每个专家被选择的概率。
self.experts:实例化 Expert 类,生成多个专家。
前向计算逻辑
接下来看一下 forward 函数。为了方便大家理解,我们把上面代码的执行打印结果也一起附上。
首先是输入 x,shape 是(batch_size, input_dim),batch_size 我们可以看作是 token 的数量,也就是序列长度。然后通过 self.gate 和 softmax 计算每个 token 在每个专家上的激活概率:
probs 的打印结果如下:我们设置的 batch_size 是 10,num_experts 是 8,所以 probs 是个 10 行 8 列的矩阵。
接着,再用 topk 算子把每个 token 的激活专家选出来:
topk_probs 和 topk_indices 的打印结果如下,因为我们设置的 top_k=3,所以每个 token 都把排名前三的概率选出来了,同时 topk_indices 把这些概率对应的专家编号也选出来了,比如第 0 个 token,激活了 5 号专家、3 号专家、0 号专家。
self.training 分支对应的是训练过程中计算损失函数的部分,我们后面再讲。
选择好专家后,就要开始计算了。计算规则是,对于每个 token,假如它选择的专家是 e1、e2、e3,概率分别是 p1、p2、p3,那么这个 token 的计算结果就是 p1e1_out+p2e2_out+p3*e3_out。
由于计算个体是每个专家,所以代码中用 for 循环遍历每个专家。我们以第 0 个专家为例,看看它的计算过程是怎样的。
首先需要确定 0 号专家的输入。由于不是每个 token 都选择了 0 号专家,所以不能把 x 直接作为输入,而是要确定一个下标向量 idxes,把 x[idxes]作为 0 号专家的输入,idxes 的值就是激活了 0 号专家的所有 token 编号,那么怎么得到 idxes 呢?代码里面是这样做的:
首先计算一个 mask:
打印结果如下:
flat_indices 是 topk_indices 平铺之后的向量。通过对比,可以看到 expert_mask 中 True 的位置和 topk_indices 中 0 的位置铺平之后是一致的,代表第 0 个专家被第 0 个 token 和第 1 个 token 激活了。
而且 expert_mask 代表的含义是:只要它的第 0-2 的位置是 True 的话,就代表被第 0 个 token 激活了,只要它的第 3-5 的位置是 True 的话,就代表被第 1 个 token 激活了,以此类推,我们可以声明一个 sample_indices 向量:
再通过下面的代码就可以把 idxes 取出来了:
也顺便把概率权重取出来:
接着把输入取出来:
打印结果如下:
再进行专家计算:
最后还需要把计算结果叠加到对应的 token 上面去:
完成上面的 for 循环之后,就把所有专家的计算任务完成了,通过 index_add_的操作,把每个 token 的计算结果也汇总了。
损失函数
损失函数包含 2 部分:专家利用率均衡和样本分配均衡。
首先是专家利用率均衡,如果每个专家被选择的概率相近,那么说明分配越均衡,损失函数越小:
然后是样本分配均衡,首先得到每个 token、每个专家的分配概率矩阵:
然后按照 token 维度(样本维度)求平均,得到每个专家被分配的 token 平均数量和平均概率:
两者相乘求和得到负载均衡损失:
样本分配越均衡,这个损失函数越小。举个例子,10 个专家,10 个样本,如果所有样本都分到 1 个专家,那么损失函数值为 10x1+0+0...+0=10,如果平均分给 10 个专家,那么损失函数值为 1x0.1+1x0.1+...+1x0.1=1。
评论