1.8 PyTorch分布式
因为本书以PyTorch作为主线,穿插结合其他框架,所以先来介绍一下PyTorch分布式的历史脉络和基本概念,看看一个机器学习系统如何一步一步进入分布式世界并且完善其功能。
1.8.1 历史脉络
关于PyTorch分布式的历史,笔者参考其发布版本,把发展历史大致分成7个阶段,分别如下。
• 使用torch.multiprocessing封装了Python原生Multiprocessing模块,这样可以利用多个CPU核。
• 导入THD(Distributed PyTorch),拥有了用于分布式计算的底层库。
• 引入torch.distributed包,允许在多台机器之间交换张量,从而可以在多台机器上使用更大的批量进行训练。
• 发布C10D库,这成为torch.distributed包和torch.nn.parallel.DistributedDataParallel包的基础后端,同时废弃THD。
• 提供了一个分布式RPC(Remote Procedure Call)框架用来支持分布式模型并行训练。它允许远程运行函数和引用远程对象,而无须复制周围的真实数据,并提供自动求导(Autograd)和优化器(Optimizer)API进行反向传播和跨RPC边界更新参数。
• 引入了弹性训练,TorchElastic提供了torch.distributed.launchCLI的一个严格超集,并增加了容错和弹性功能。
• 引入了流水线并行,也就是torchgpipe。
PyTorch的历史脉络如图1-16所示。
图1-16
1.8.2 基本概念
PyTorch分布式相关的基础模块包括Multiprocessing模块和torch.distributed模块,下面分别进行介绍。
1.Multiprocessing模块
PyTorch的Multiprocessing模块封装了Python原生的Multiprocessing模块,在API上百分之百兼容,同时注册了定制的Reducer(归约器)类,可以使用IPC机制(共享内存)让不同的进程对同一份数据进行读写。但是其工作方式在CUDA上有很多弱点,比如必须规定各种进程的生命周期如何,导致CUDA上的Multiprocessing模块的处理结果经常与预期不符。
2.torch.distributed模块
PyTorch中的torch.distributed模块针对多进程并行提供了通信原语,使得这些进程可以在一个或多个计算机上运行的几个Worker之间进行通信。torch.distributed模块的并行方式与Multiprocessing(torch.multiprocessing)模块不同,torch.distributed模块支持多个通过网络连接的机器,并且用户必须为每个进程显式启动主训练脚本的单独副本。
在单机且同步模型的情况下,torch.distributed或者torch.nn.parallel.DistributedDataParallel同其他数据并行方法(如torch.nn.DataParallel)相比依然会具有优势,具体如下。
• 每个进程维护自己的优化器,并在每次迭代中执行一个完整的优化step。由于梯度已经聚集在一起并且是跨进程平均的,因此梯度对于每个进程都相同,这意味着不需要参数广播步骤,大大减少了在节点之间传输张量所花费的时间。
• 每个进程都包含一个独立的Python解释器,消除了额外的解释器开销和GIL颠簸,这些开销来自单个Python进程驱动多个执行线程、多个模型副本或多个GPU的开销。这对于严重依赖Python Runtime(运行时)的模型尤其重要,这样的模型通常具有递归层或许多小组件。
从PyTorch v1.6.0开始,torch.distributed可以分为三个主要组件,具体如下。
• 集合通信(C10D)库:torch.distributed的底层通信主要使用集合通信库在进程之间发送张量,集合通信库提供集合通信API和P2P通信API,这两种通信API分别对应另外两个主要组件DDP和RPC。其中DDP使用集合通信,RPC使用P2P通信。通常,开发者不需要直接使用此原始通信API,因为DDP和RPC可以服务于许多分布式训练场景。但在某些实例中此API仍然有用,比如分布式参数平均。
• 分布式数据并行训练组件(DDP):DDP是单程序多数据训练范式。它会在每个进程上复制模型,对于每个模型副本其输入数据样本都不相同。在每轮训练之后,DDP负责进行梯度通信,这样可以保持模型副本同步,而且梯度通信可以与梯度计算重叠以加速训练。
• 基于RPC的分布式训练组件(torch.distributed.rpc包):该组件旨在支持无法适应数据并行训练的通用训练结构,如参数服务器范式、分布式流水线并行,以及DDP与其他训练范式的组合。该组件有助于管理远程对象生命周期并将自动求导引擎扩展到机器边界之外,支持通用分布式训练场景。torch.distributed.rpc有四大支柱,具体如下。
■ RPC:支持在远端Worker上运行给定的函数。
■ Remote Ref:有助于管理远程对象的生命周期。
■ 分布式自动求导:将自动求导引擎扩展到机器边界之外。
■ 分布式优化器:可以自动联系所有参与的Worker,以使用分布式自动求导引擎计算的梯度来更新参数。
图1-17(见彩插)展示了PyTorch分布式包的内部架构和逻辑关系。
图1-17