# Define the structure of each expert (a small gated MLP, DeepSeek-V3 style).
moe_intermediate_size = 5
class DeepseekV3MLP(nn.Module):
    """Gated feed-forward expert: down_proj(silu(gate_proj(x)) * up_proj(x)).

    Args:
        hidden_size: model hidden dimension (input and output width).
        intermediate_size: inner expansion dimension of the expert.
    """

    def __init__(self, hidden_size=None, intermediate_size=None):
        super().__init__()
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        # Bias-free projections, matching the DeepSeek-V3 reference MLP.
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        # nn.SiLU() computes exactly the same function as ACT2FN["silu"],
        # without requiring the transformers activation registry.
        self.act_fn = nn.SiLU()

    def forward(self, x):
        """Apply the gated MLP: [..., hidden_size] -> [..., hidden_size]."""
        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
        return down_proj
# Build the pool of routed experts: one gated MLP per expert index.
experts = nn.ModuleList(
    DeepseekV3MLP(hidden_dim, moe_intermediate_size)
    for _ in range(n_routed_experts)
)
x = hidden_states # [bsz*seq_len, hidden_dim]
topk_ids = topk_idx # [bsz*seq_len, top_k]
cnts = topk_ids.new_zeros((topk_ids.shape[0], n_routed_experts)) # [bsz*seq_len, n_routed_experts]
print("cnts: ", cnts)
# cnts records each token's expert routing: a 1 in column e means the token
# was routed to expert e.
cnts.scatter_(1, topk_ids, 1) # [bsz*seq_len, n_routed_experts]
print("cnts: ", cnts)
# Count how many tokens each expert receives.
tokens_per_expert = cnts.sum(dim=0) # [n_routed_experts]
# Sort the flattened (token, slot) assignments by expert id, so assignments
# for the same expert become contiguous.
idxs = topk_ids.view(-1).argsort()
print("idxs: ", idxs)
# Gather the token features in expert order; idxs // top_k recovers the
# original token row for each flattened assignment.
sorted_tokens = x[idxs // topk_ids.shape[1]]
print("sorted_tokens: ", sorted_tokens)
sorted_tokens_shape = sorted_tokens.shape
# This script can run on a single device.
ep_size = 1
# With ep_size == 1, every expert lives on this rank.
experts_per_rank = n_routed_experts
print("tokens_per_expert.shape[0]: ", tokens_per_expert.shape[0])
# Multi-device expert-parallel (EP) path: exchange tokens across ranks.
if ep_size > 1:
    # [ep_size, n_routed_experts // ep_size] -> [ep_size]
    tokens_per_ep_rank = tokens_per_expert.view(ep_size, -1).sum(dim=1)
    # [n_routed_experts]
    tokens_per_expert_group = tokens_per_expert.new_empty(
        tokens_per_expert.shape[0]
    )
    # tokens_per_expert_group receives, from every rank, the per-expert
    # counts of tokens destined for this rank.
    dist.all_to_all_single(tokens_per_expert_group, tokens_per_expert)
    # [ep_size, n_routed_experts // ep_size] -> [ep_size]
    output_splits = (
        tokens_per_expert_group.view(ep_size, -1)
        .sum(1)
        .cpu()
        .numpy()
        .tolist()
    )
    # [total_token_on_this_rank, hidden_dim]: buffer holding every token
    # that must be computed on this rank.
    gathered_tokens = sorted_tokens.new_empty(
        tokens_per_expert_group.sum(dim=0).cpu().item(), sorted_tokens.shape[1]
    )
    input_split_sizes = tokens_per_ep_rank.cpu().numpy().tolist()
    # After this all-to-all, gathered_tokens contains every token to be
    # processed by this rank's local experts.
    dist.all_to_all(
        list(gathered_tokens.split(output_splits)),
        list(sorted_tokens.split(input_split_sizes)),
    )
    # [experts_per_rank,]: tokens sent to each local expert, summed over
    # all sending ranks: [expert1_token_num, expert2_token_num, ...]
    tokens_per_expert_post_gather = tokens_per_expert_group.view(
        ep_size, experts_per_rank
    ).sum(dim=0)
    gatherd_idxs = np.zeros(shape=(gathered_tokens.shape[0],), dtype=np.int32)
    s = 0
    for i, k in enumerate(tokens_per_expert_group.cpu().numpy()):
        # Tag each received token with its local expert index.
        gatherd_idxs[s : s + k] = i % experts_per_rank
        s += k
    # Reorder received tokens so they are grouped by local expert.
    gatherd_idxs = gatherd_idxs.argsort() # [expert_total_token_num,]
    sorted_tokens = gathered_tokens[gatherd_idxs] # [expert_total_token_num, hidden_dim]
    tokens_per_expert = tokens_per_expert_post_gather # [experts_per_rank,]
tokens_per_expert = tokens_per_expert.cpu().numpy()
print("tokens_per_expert: ", tokens_per_expert)
# Run every expert over its contiguous slice of the expert-sorted buffer.
expert_outputs = []
offset = 0
ep_rank = 0
for local_id, token_count in enumerate(tokens_per_expert):
    next_offset = offset + token_count
    if token_count == 0:
        continue
    mlp = experts[local_id + ep_rank * experts_per_rank]
    chunk = sorted_tokens[offset:next_offset]
    expert_outputs.append(mlp(chunk))
    offset = next_offset
# Concatenate the per-expert results into one tensor.
if expert_outputs:
    outs = torch.cat(expert_outputs, dim=0)
else:
    outs = sorted_tokens.new_empty(0)
print("outs: ", outs) # [bsz*seq_len*top_k, hidden_dim]
# Under EP parallelism, results computed on this rank for tokens owned by
# other ranks must be sent back to their owners.
if ep_size > 1:
    new_x = torch.empty_like(outs)
    # Undo the expert-grouped permutation so rows are back in the order in
    # which the sending ranks delivered them.
    new_x[gatherd_idxs] = outs
    gathered_tokens = new_x.new_empty(*sorted_tokens_shape)
    dist.all_to_all(
        list(gathered_tokens.split(input_split_sizes)),
        list(new_x.split(output_splits)),
    )
    outs = gathered_tokens
new_x = torch.empty_like(outs)
# Reorder outs so rows follow the original flattened token/slot order.
new_x[idxs] = outs
# Weighted sum of each token's top-k expert outputs.
final_out = (
    new_x.view(*topk_ids.shape, -1)
    .type(topk_weight.dtype)
    .mul_(topk_weight.unsqueeze(dim=-1))
    .sum(dim=1)
    .type(new_x.dtype)
)
print("final_out: ", final_out)
# (end of MoE dispatch demo)