分布式机器学习 (Parameter Server)
分布式机器学习中,参数服务器(Parameter Server)用于管理和共享模型参数,其基本思想是将模型参数存储在一个或多个中央服务器上,并通过网络将这些参数共享给参与训练的各个计算节点。每个计算节点可以从参数服务器中获取当前模型参数,并将计算结果返回给参数服务器进行更新。
为了保持模型一致性,通常采用下列两种方法:
将模型参数保存在一个集中的节点上,当一个计算节点要进行模型训练时,可从集中节点获取参数,进行模型训练,然后将更新后的模型推送回集中节点。由于所有计算节点都从同一个集中节点获取参数,因此可以保证模型一致性。
每个计算节点都保存模型参数的副本,因此要定期强制同步模型副本,每个计算节点使用自己的训练数据分区来训练本地模型副本。在每个训练迭代后,由于使用不同的输入数据进行训练,存储在不同计算节点上的模型副本可能会有所不同。因此,每一次训练迭代后插入一个全局同步的步骤,这将对不同计算节点上的参数进行平均,以便以完全分布式的方式保证模型的一致性,即 All-Reduce 范式
PS 架构
在该架构中,包含两个角色:parameter server 和 worker
parameter server 将被视为 master 节点在 Master/Worker 架构,而 worker 将充当计算节点负责模型训练
整个系统的工作流程分为 4 个阶段:
Pull Weights: 所有 worker 从参数服务器获取权重参数
Push Gradients: 每一个 worker 使用本地的训练数据训练本地模型,生成本地梯度,之后将梯度上传参数服务器
Aggregate Gradients:收集到所有计算节点发送的梯度后,对梯度进行求和
Model Update:计算出累加梯度,参数服务器使用这个累加梯度来更新位于集中服务器上的模型参数
可见,上述的 Pull Weights 和 Push Gradients 涉及到通信,首先对于 Pull Weights 来说,参数服务器同时向 worker 发送权重,这是一对多的通信模式,称为 fan-out 通信模式。假设每个节点(参数服务器和工作节点)的通信带宽都为 1。假设在这个数据并行训练作业中有 N 个工作节点,由于集中式参数服务器需要同时将模型发送给 N 个工作节点,因此每个工作节点的发送带宽(BW)仅为 1/N。另一方面,每个工作节点的接收带宽为 1,远大于参数服务器的发送带宽 1/N。因此,在拉取权重阶段,参数服务器端存在通信瓶颈。
对于 Push Gradients 来说,所有的 worker 并发地发送梯度给参数服务器,称为 fan-in 通信模式,参数服务器同样存在通信瓶颈。
基于上述讨论,通信瓶颈总是发生在参数服务器端,将通过负载均衡解决这个问题
将模型划分为 N 个参数服务器,每个参数服务器负责更新 1/N 的模型参数。实际上是将模型参数分片(sharded model)并存储在多个参数服务器上,可以缓解参数服务器一侧的网络瓶颈问题,使得参数服务器之间的通信负载减少,提高整体的通信效率。
代码实现
定义网络结构:
如上定义了一个简单的 CNN
实现参数服务器:
get_weights 获取权重参数,update_model 更新模型,采用 SGD 优化器
实现 worker:
Pull_weights 获取模型参数,push_gradients 上传梯度
训练
训练数据集为 MNIST
版权声明: 本文为 InfoQ 作者【这我可不懂】的原创文章。
原文链接:【http://xie.infoq.cn/article/f2f028ef1ea2f70d9ce4ed595】。
本文遵守【CC-BY 4.0】协议,转载请保留原文出处及本版权声明。
评论