数据并行:提升训练吞吐的高效方法 |深度学习分布式训练专题
数据并行是大规模深度学习训练中非常成熟和常用的并行模式。本文将介绍数据并行的原理和主流实现方案,使用数据并行加速训练过程中需要注意的问题,以及如何优化数据并行进一步提高训练速度。希望能帮助用户更好的理解和使用数据并行策略。
什么是数据并行
在近年来的深度学习模型训练中,使用更多的训练数据和更大的模型趋势未改。更大的模型和数据量意味着更多的计算量和存储需求,也意味着更久的训练时间。那么如何将计算和存储需求分布到多个训练设备来提升训练速度,是关键问题。
数据并行(data parallelism)是解决上述问题的的一种并行策略,其主要逻辑遵循 Single Program Multiple Data 的原则,即在数据并行的模型训练中,训练任务被切分到多个进程(设备)上,每个进程维护相同的模型参数和相同的计算任务,但是处理不同的数据(batch data)。通过这种方式,同一全局数据(global batch)下的数据和计算被切分到了不同的进程,从而减轻了单个设备上的计算和存储压力。*Single Program Multiple Data:https://en.wikipedia.org/wiki/SPMD
在深度学习模型训练中,数据并行可作为通过增加并行训练设备来提高训练吞吐量(global batch size per second) 的方法。以常见的 ResNet50 模型使用 32GB V100 卡训练为例。假设训练时单卡最大能支持的 local batch size 为 256,训练一个 step 的耗时为 1 秒。则单卡训练时的吞吐为 256 imgs/s。
如果我们使用 32 张 V100 做数据并行训练,假设没有损耗,那么理论上的训练吞吐可达到 32 x 256 = 8192 imgs/。实际上由于数据并行时多机多卡的通信消耗等,实际加速效率会有折扣,但在加速效率为 0.8 时,训练吞吐也可达到 32 x 256 x 0.8 = 6554 imgs/s。如果使用更多的 GPU,并行训练的速度将会更高,大大减少训练需要的时间。
深度学习训练中数据并行的实现方式可以有多种,下文介绍的数据并行是基于 Distributed Synchronous SGD 的梯度同步数据并行,这是目前主流深度学习训练框架中数据并行的实现方式。此外,还会介绍数据并行实现所需要注意的问题以及如何优化来让数据并行实现更高的加速比,提升训练速度。*Distributed Synchronous SGD:https://arxiv.org/pdf/1602.06709.pdf
在飞桨框架中进行数据并行训练的示例可以参考飞桨数据并行接口文档。*飞桨数据并行接口文档:https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api/paddle/DataParallel_cn.html#dataparallel
数据并行的过程
相比其它的并行模式,数据并行的实现过程比较简单,关键是实现 Single Program Multiple Data 并行模式中的要求:
Single Program: 在深度学习训练中 single program 可以理解为每个进程上模型的组网和参数相同。
Multiple Data: 在深度学习训练中为每个进程上模型处理不同 mini-batch 的数据。
2.1 输入数据切分第二个条件 —— 输入数据切分实现上比较简单,一般有两种常用的实现方式:方式一:在每个训练 Epoch 开始前,将整个训练数据集根据并行进程数划分,每个进程只读取自身切分的数据。方式二: 数据的读取仅由具体某个进程负责(假设为 rank0)。rank0 在数据读取后同样根据并行进程数将数据切分成多块,再将不同数据块发送到对应进程上。
方式一相对方式二不需要进行数据通信,训练效率更高,飞桨框架中默认的数据并行使用方式一完成数据在不同进程上的切分。
2.2 模型参数同步数据并行实现的关键问题在于如何保证训练过程中每个进程上模型的参数相同。因为训练过程的每一个 step 都会更新模型参数,每个进程处理不同的数据会得到不同的 Loss。由 Loss 计算反向梯度并更新模型参数后,如何保证进程间模型参数正确同步,是数据并行需要解决的最主要问题。根据下面中的梯度更新公式,只要保证以下两点就能解决这个问题:
保证每个进程模型参数初始相同有两种常用的实现方法:方法一:所有进程在参数初始时使用相同的随机种子并以相同的顺序初始化所有参数。
方法二:通过个具体进程初始化全部模型参数,之后由该进程向其他所有进程广播模型参数。
基于上述任意一种方法使每个进程得到一份相同的模型初始化参数后,梯度同步的数据并行训练就可以进一步拆解为如下三个部分:
2.2.1 前向计算每个进程根据自身得到的输入数据独立前向计算,因为输入数据不同每个进程会得到不同的 Loss。
2.2.2 反向计算每个进程根据自身的前向计算独立进行反向计算,因为每个进程上的 Loss 不同,每个进程上在反向中会计算出不同的梯度。这时一个关键的操作是要在后续的更新步骤之前,对所有进程上的梯度进行同步,保证后续更新步骤中每个进程使用相同的全局梯度更新模型参数。
这一个梯度同步过程是用一个 Allreduce sum 同步通信操作实现的,对梯度使用 Allreduce sum 操作后每个进程上得到的梯度是相同的,这时候的梯度值等于所有进程上梯度对应位置相加的和,然后每个进程用 Allreduce 后的梯度和除以数据并行中的进程数,这样得到的梯度是同步之前所有进程上梯度的平均值。如下图所示。
2.2.3 参数更新每个进程经过上述步骤后得到相同全局梯度,然后各自独立地完成参数更新。因为更新前模型各进程间的参数是相同的,更新中所使用的梯度也是相同的,所以更新后各进程上的参数也是相同的。
上述是主流框架中数据并行的实现过程。和单卡训练相比,最主要的区别在于反向计算中的梯度需要在所有进程中进行同步,保证每个进程上最终得到的是所有进程上梯度的平均值。
数据并行训练中的注意问题
3.1 SyncBatchNorm 前面提到,一般情况下各进程前向计算是独立的,不涉及同步问题。但使用批归一化(Batch Normalization)技术的场景下有新的挑战。
批归一化通过对输入 tensor 在 batch size 维度做归一化来提升训练过程的数值稳定性。但是数据并行训练中 global batch size 被切分到不同的进程之上,每个进程上只有部分的输入数据,这样批归一化在计算输入 tensor batch 维度的平均值(Mean)和方差(Variance) 时仅使用了部分的 batch 而非 global batch,会导致部分对 batch size 比较敏感的模型(e.g. 图像分割)的精度下降。
这类模型在数据并行训练中可以使用 SyncBatchNorm 策略来保证模型精度,该策略在模型训练前向 BN 层计算 mean 和 variance 时加入额外的同步通信,使用所有数据并行进程上的 tensors 而非自身进程上的 tensor 来计算 tensor batch 维度的 mean 和 variance。具体过程如下图所示:
每个进程根据自己部分的数据计算 batch 维度上的 local sum 和 local square sum 值。
在所有卡间同步得到 global sum 和 global square sum。
使用 global sum 和 global square sum 计算 global mean 和 global standard deviation。
最后使用 global 的 mean 和 standard deviation 对 batch data 进行归一化。
像语言类模型中主要使用的 Layer Normalization,是在单个数据而非批数据的维度输入 tensor 计算 mean 和 variance,数据并行并不会影响其计算逻辑,不需要像 Batch Normalization 一样做专门的调整。
3.2 数据切分均匀目前主流训练框架数据并行训练中使用 Allreduce 同步通信操作来实现所有进程间梯度的同步,这要求数据在各进程间的切分要做到尽量均匀,这个问题看起来很简单,但在实际实现中也要特别注意以下两点:
1.要求所有进程每个训练 step 输入的 local batch size 大小相同。这是因为模型训练时需要的是所有样本对应梯度的全局平均值。如果每个进程的 local batch size 不相同,在计算梯度平均值时,除了要在所有进程间使用 Allreduce 同步梯度,还需要要同步每个进程上 local batch size。
当限制所有进程上的 local batch size 相同时,各进程可以先在本地计算本进程上梯度的 local 平均值,然后对梯度在所有进程间做 Allreduce sum 同步,同步后的梯度除以进程数得到的值就是梯度的全局平均值。这样实现可以减少对 local batch size 同步的需求,提升训练速度。
2.要保证所有进程上分配到相同的 batch 数量。因为 Allreduce 是同步通信操作,需要所有进程同时开始并同时结束一次通信过程。当有的进程的 batch 数量少于其它进程时,该进程会因为没有新的数据 batch 而停止训练,但其他进程会继续进行下一 batch 的训练;当进入下一 batch 训练的进程执行到第一个 Allreduce 通信操作时,会一直等待其他所有进程到达第一个 Allreduce 一起完成通信操作。
但因为缺少 batch 的进程,已经停止训练不会执行这次 allreduce 操作,导致其它进程将会一直等待,呈现挂死态。数据并行中 batch 数量在进程的均匀切分通常是由 data loader 来保障的,如果训练数据集样本数无法整除数据并行进程数,那么有一种策略是部分拿到多余样本的进程可以通过抛弃最后一个 batch 来保证所有进程 batch 数量的一致。
数据并行的优化技巧
4.1 通信融合(Fuse Allreduce)从上文我们知道数据并行中需要同步每一个模型梯度,这是通过进程间的 Allreduce 通信实现的。如果一个模型有非常多的参数,则数据并行训练的每一个 step 中会有非常多次的 Allreduce 通信。
通信的耗时可以从通信延迟(lantency)和数据传输时间消耗两方面考虑。单次通信延迟时间相对固定,而传输时间由通信的数据量和带宽决定。减少总的通信消耗,可以通过减少通信频率来实现,通信融合是一个可行的手段,通过将 N 个梯度的 Allreduce 通信合并成一次 Allreduce 通信,可以减少 N-1 次通信延迟时间。
常用的 Allreduce 融合实现方式是在通信前将多个梯度 tensors 拼接成一个内存地址连续的大 tensor,梯度同步时仅对拼接后的大 tensor 做一次 Allreduce 操作。参数更新时将大 tensor 切分还原回之前的多个小 tensors, 完成每个梯度对应参数的更新。
4.2 通信计算重叠(Overlapping)除了降低绝对的通信耗时,还可以从降低整体训练耗时角度来优化,可以考虑通信和计算的异步流水实现。数据并行中的梯度同步 Allreduce 通信是在训练的反向过程中进行的,而 Allreduce 后得到的同步梯度是在训练的更新过程中才被使用,在反向中并没有被使用。也就是说上一个梯度的通信和下一个梯度的计算间并没有依赖,通信和计算可以并行,让两者的耗时相互重叠掩盖,减少反向的耗时。
通信和计算的重叠通常是将通信和计算算子调度到不同的流(stream)上实现的。通信算子调度到通信流,计算算子调度到计算流,同一个流上的算子间是顺序执行的,不同流上的算子可以并行执行,从而实现反向中梯度通信和计算的并行重叠。需要注意的是,当通信和计算被调度在不同的流上执行时,需要考虑两个流之间依赖和同步关系。
在梯度同步的数据并行场景中,开发者需要需要通过 stream 间的同步功能保证:
某个梯度 Allreduce 通信进行前,该梯度的反向计算已经完成。
某个梯度对应参数的更新计算开始前,该梯度的 Allreduce 通信已经完成。
以上两个方法是数据并行中常用的减少通信时间消耗,提高并行加速比的优化策略。如果能做到通信和计算的重叠程度越高,那么数据并行的加速比越接近 100%,多卡并行对训练吞吐提升的效率也就越高。
总结与结论
本文介绍了深度学习训练中的数据并行,介绍了基于 distributed synchronous SGD 的梯度同步数据并行实现方式和训练前向、反向、更新的过程;另外还介绍了使用数据并行中批归一化结合使用时需要注意的问题和常用的数据并行训练速度优化技巧。这些都是工程上实现数据并行时需要考虑的主要问题,希望能帮助读者在工程实现角度更进一步理解数据并行。
但是在另一方面,数据并行在增大训练的 global batch size 后,虽然增加了模型的训练吞吐,但模型的收敛可能会受到影响。这是数据并行在算法层面需要解决的大 batch size 收敛问题。针对这类算法问题,感兴趣的读者可以参考 LARS 和 LAMB 等 layer-wise-lr-adaptive 优化算法。
百度 AI 开发者社区百度AI开发者社区 ,为全国各地开发者提供一个交流、分享、答疑解惑的平台,让开发者在研发路上不再“孤军奋战”,通过不断地交流与探讨找出更好的技术解决方案。如果你想尝试各种人工智能技术、开拓应用场景,赶快加入百度 AI 社区,你对 AI 的所有畅想,在这里都可以实现!
评论