写点什么

AWS 上实现超大规模模型训练的近线性扩展

作者:qife
  • 2025-07-28
    福建
  • 本文字数:868 字

    阅读完需:约 3 分钟

AWS 上实现超大规模模型训练的近线性扩展

当前最先进的语言模型具有数十亿参数。要在可控时间内训练这些模型,需要将工作负载分配到大型计算集群上。理想情况下,训练时间应随着集群规模扩大而线性减少。但由于节点间协调工作所需的通信会抵消并行化带来的收益,实现线性扩展非常困难。


我们近期优化了微软 DeepSpeed 分布式训练库的通信效率,在最多 64 个 GPU 上显著提升了性能。但当规模从数十个 GPU 扩展到数百个 GPU 时,在公有云环境中通信开销再次成为效率瓶颈。


在即将于 2023 年 VLDB 会议上发表的论文中,我们提出了一种名为 MiCS(最小化通信规模)的方法,可在云环境中实现数百个 GPU 的高效模型训练。与 DeepSpeed 和 FairScale 等现有框架将模型状态划分到所有 GPU 不同,MiCS 会创建模型状态的多个副本,并将每个副本划分到 GPU 子集中。


实验结果显示,在不同规模的 BERT 模型上,使用 p3dn.24xlarge 实例集群评估时,MiCS 在吞吐量和扩展效率方面都有显著提升。该方法能实现近线性扩展(如下图矩形框所示),相比 DeepSpeed-v0.5.6 内置的 ZeRO 优化器,吞吐量最高提升 2.82 倍。

规模感知的模型分区

MiCS 将集群中的 GPU 划分为多个"分区组",每个组持有完整的模型状态副本。这种方法将频繁的通信操作(如参数收集)限制在固定数量的 GPU 内,有效控制了通信开销随集群规模增长的问题。

分层通信策略

当单个模型状态副本的内存需求超过单节点 GPU 总内存时,MiCS 采用分层通信策略减少节点间通信参与者的数量。例如在双节点四 GPU 场景下,通信量因子从 3/4 降至 1/2。

两跳梯度同步

MiCS 通过将梯度同步开销分摊到多个微步中,实现了高效的两跳梯度同步机制。这使得在 p4de.24xlarge 机器上训练 1750 亿参数模型时,每个 GPU 能达到 169 万亿次浮点运算(理论峰值的 54.2%)。


当集群规模从 128GPU 扩展到 512GPU 时,MiCS 实现了 99.4%的弱扩展效率,而 DeepSpeed ZeRO 第三阶段仅达到 72%。我们正在将 MiCS 开源,相信它将大幅降低在 Amazon EC2 平台上训练大模型的时间和成本。


致谢:Yida Wang, Justin Chiu, Roshan Makhijani, RJ, Stephen Rawls, Xin Jin 更多精彩内容 请关注我的个人公众号 公众号(办公 AI 智能小助手)公众号二维码


办公AI智能小助手


用户头像

qife

关注

还未添加个人签名 2021-05-19 加入

还未添加个人简介

评论

发布
暂无评论
AWS上实现超大规模模型训练的近线性扩展_AWS_qife_InfoQ写作社区