AWS 上实现超大规模模型训练的近线性扩展
AWS 上实现超大规模模型训练的近线性扩展
当前最先进的语言模型具有数十亿参数。要在可控时间内训练这些模型,需要将工作负载分配到大型计算集群上。理想情况下,训练时间应随着集群规模扩大而线性减少。但由于节点间协调工作所需的通信会抵消并行化带来的收益,实现线性扩展非常困难。
我们近期优化了微软 DeepSpeed 分布式训练库的通信效率,在最多 64 个 GPU 上显著提升了性能。但当规模从数十个 GPU 扩展到数百个 GPU 时,在公有云环境中通信开销再次成为效率瓶颈。
在即将于 2023 年 VLDB 会议上发表的论文中,我们提出了一种名为 MiCS(最小化通信规模)的方法,可在云环境中实现数百个 GPU 的高效模型训练。与 DeepSpeed 和 FairScale 等现有框架将模型状态划分到所有 GPU 不同,MiCS 会创建模型状态的多个副本,并将每个副本划分到 GPU 子集中。
实验结果显示,在不同规模的 BERT 模型上,使用 p3dn.24xlarge 实例集群评估时,MiCS 在吞吐量和扩展效率方面都有显著提升。该方法能实现近线性扩展(如下图矩形框所示),相比 DeepSpeed-v0.5.6 内置的 ZeRO 优化器,吞吐量最高提升 2.82 倍。
规模感知的模型分区
MiCS 将集群中的 GPU 划分为多个"分区组",每个组持有完整的模型状态副本。这种方法将频繁的通信操作(如参数收集)限制在固定数量的 GPU 内,有效控制了通信开销随集群规模增长的问题。
分层通信策略
当单个模型状态副本的内存需求超过单节点 GPU 总内存时,MiCS 采用分层通信策略减少节点间通信参与者的数量。例如在双节点四 GPU 场景下,通信量因子从 3/4 降至 1/2。
两跳梯度同步
MiCS 通过将梯度同步开销分摊到多个微步中,实现了高效的两跳梯度同步机制。这使得在 p4de.24xlarge 机器上训练 1750 亿参数模型时,每个 GPU 能达到 169 万亿次浮点运算(理论峰值的 54.2%)。
当集群规模从 128GPU 扩展到 512GPU 时,MiCS 实现了 99.4%的弱扩展效率,而 DeepSpeed ZeRO 第三阶段仅达到 72%。我们正在将 MiCS 开源,相信它将大幅降低在 Amazon EC2 平台上训练大模型的时间和成本。
致谢:Yida Wang, Justin Chiu, Roshan Makhijani, RJ, Stephen Rawls, Xin Jin 更多精彩内容 请关注我的个人公众号 公众号(办公 AI 智能小助手)公众号二维码

评论