# Define the structure of a single expert.
moe_intermediate_size = 5


class DeepseekV3MLP(nn.Module):
    """One MoE expert: a SwiGLU-style feed-forward block.

    Computes down_proj(silu(gate_proj(x)) * up_proj(x)); all three
    projections are bias-free linear layers.
    """

    def __init__(self, hidden_size=None, intermediate_size=None):
        super().__init__()
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        # ACT2FN is assumed to come from transformers' activation registry — TODO confirm.
        self.act_fn = ACT2FN["silu"]

    def forward(self, x):
        # Gated activation: silu(gate) elementwise-scales the up projection,
        # then the result is projected back to hidden_size.
        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
        return down_proj


# Build the expert group: n_routed_experts independent MLPs.
# (Loop variable is unused, so use `_` instead of `i`.)
experts = nn.ModuleList(
    [
        DeepseekV3MLP(hidden_dim, moe_intermediate_size)
        for _ in range(n_routed_experts)
    ]
)
x = hidden_states # [bsz*seq_len, hidden_dim]topk_ids = topk_idx # [bsz*seq_len, top_k]cnts = topk_ids.new_zeros((topk_ids.shape[0], n_routed_experts)) # [bsz*seq_len, n_routed_experts]print("cnts: ", cnts)# cnts记录每个token的专家路由情况cnts.scatter_(1, topk_ids, 1) # [bsz*seq_len, n_routed_experts]print("cnts: ", cnts)# 统计每个专家的token数量tokens_per_expert = cnts.sum(dim=0) # [n_routed_experts]# 按照expert编号的顺序,把每个expert对应的token下标取出来idxs = topk_ids.view(-1).argsort()print("idxs: ", idxs)# 按照expert编号的顺序,把每个expert需要处理的token特征取出来sorted_tokens = x[idxs // topk_ids.shape[1]]print("sorted_tokens: ", sorted_tokens)
sorted_tokens_shape = sorted_tokens.shape# 这个脚本可以在单卡上运行ep_size = 1# 所有专家都放在一个卡上experts_per_rank = n_routed_expertsprint("tokens_per_expert.shape[0]: ", tokens_per_expert.shape[0])# 多卡EP并行场景if ep_size > 1: # [ep_size, n_routed_experts // ep_size]->[ep_size] tokens_per_ep_rank = tokens_per_expert.view(ep_size, -1).sum(dim=1) # [n_routed_experts] tokens_per_expert_group = tokens_per_expert.new_empty( tokens_per_expert.shape[0] ) dist.all_to_all_single(tokens_per_expert_group, tokens_per_expert) # tokens_per_expert_group获取的是各个rank上分给本rank的token情况 # [ep_size, n_routed_experts // ep_size] -> [ep_size] output_splits = ( tokens_per_expert_group.view(ep_size, -1) .sum(1) .cpu() .numpy() .tolist() ) # [total_token_on_this_rank, hidden_dim], 存储所有需要在本rank上计算的Token。 gathered_tokens = sorted_tokens.new_empty( tokens_per_expert_group.sum(dim=0).cpu().item(), sorted_tokens.shape[1] ) input_split_sizes = tokens_per_ep_rank.cpu().numpy().tolist() # gathered_tokens记录了所有需要在本rank上计算的Token dist.all_to_all( list(gathered_tokens.split(output_splits)), list(sorted_tokens.split(input_split_sizes)), ) # [experts_per_rank,], 记录的是所有节点发送给本rank上各expert的token数量,[expert1_token_num, expert2_token_num,...] tokens_per_expert_post_gather = tokens_per_expert_group.view( ep_size, experts_per_rank ).sum(dim=0) gatherd_idxs = np.zeros(shape=(gathered_tokens.shape[0],), dtype=np.int32) s = 0 for i, k in enumerate(tokens_per_expert_group.cpu().numpy()): # 记录每个token对应的expert编号 gatherd_idxs[s : s + k] = i % experts_per_rank s += k gatherd_idxs = gatherd_idxs.argsort() # [expert_total_token_num,] sorted_tokens = gathered_tokens[gatherd_idxs] # [expert_total_token_num, hidden_dim] tokens_per_expert = tokens_per_expert_post_gather # [experts_per_rank,]tokens_per_expert = tokens_per_expert.cpu().numpy()print("tokens_per_expert: ", tokens_per_expert)
# Run every expert over its contiguous slice of the expert-sorted token buffer.
outputs = []
start_idx = 0
ep_rank = 0  # single-rank script, so the global expert offset is zero
for expert_id, count in enumerate(tokens_per_expert):
    end_idx = start_idx + count
    if count:
        # Select the local expert and feed it its slice of tokens.
        local_expert = experts[expert_id + ep_rank * experts_per_rank]
        outputs.append(local_expert(sorted_tokens[start_idx:end_idx]))
        start_idx = end_idx
# Concatenate all experts' results back into a single tensor; if no expert
# received any tokens, fall back to an empty tensor that keeps
# sorted_tokens' dtype/device. (Idiomatic truthiness: `if outputs`
# instead of `if len(outputs)`.)
outs = torch.cat(outputs, dim=0) if outputs else sorted_tokens.new_empty(0)
print("outs: ", outs)  # [bsz*seq_len*top_k, hidden_dim]
# EP并行情况下,需要把其他rank上的序列token在本rank上计算的结果返回if ep_size > 1: new_x = torch.empty_like(outs) # 把输出按照原来的顺序排列,即各rank给本rank发送的token顺序 new_x[gatherd_idxs] = outs gathered_tokens = new_x.new_empty(*sorted_tokens_shape) dist.all_to_all( list(gathered_tokens.split(input_split_sizes)), list(new_x.split(output_splits)), ) outs = gathered_tokens
new_x = torch.empty_like(outs)# 把outs的顺序进行重排,让从上到下是按token的顺序进行排列new_x[idxs] = outs# 把每个token的多个expert处理结果进行加权求和final_out = ( new_x.view(*topk_ids.shape, -1) .type(topk_weight.dtype) .mul_(topk_weight.unsqueeze(dim=-1)) .sum(dim=1) .type(new_x.dtype))print("final_out: ", final_out)
# (End of article; reader-comments section omitted.)