对比 PyTorch、TensorFlow、JAX、Theano,我发现都在关注两大问题
作者|王益
OneFlow 社区编译
翻译|杨婷
最近,我在处理 PyTorch 分布式和 TorchRec 相关的工作,为此,我开始学习 PyTorch 2.0。在业余时间,我也在跟着 Alpa 作者学习 JAX 和 XLA。如今回顾这些技术,我发现它们的关注点似乎都是如下两个问题:
包含自动求导和并行在内的函数转换,例如 vmap, pmap 和 pjit 等;
异构计算,CPU 负责控制流,GPU/TPU 负责张量计算和集合通信。
本文档中的所有例子都支持在 Colab 中运行:
1、函数转换
“函数转换”意为将一个程序转变成另一个程序,最常见的例子是自动求导(autograd)。自动求导采用用户编写的前向过程并创建后向过程,对于用户来说,编写自动求导通常都太过复杂。函数转换的主要难点在于:在编写函数转换算法时以何种方式表示输入和输出过程。
Theano:显式地构建 IR
Theano 是最早的深度学习工具之一,也就是如今为人们所熟知的 Aesara 项目。Theano 有一个允许用户在内存中将 IR 构建为数据结构的 API,因此 Theano 可实现自动求导,并将结果输出为 Python 函数。
TensorFlow 1.x:用于运行 IR 的虚拟机
TensorFlow 1.x 明确保留了构建 IR 的想法。若在 TensorFlow 中运行上述示例,结果不会有什么差别;但倘若在 TensorFlow 1.x 中来运行,最大的差别在于:我们不会将后向 IR 转换为 Python 函数,并使用 Python 解释器来运行。相反,我们会在 TensorFlow runtime 中来运行。
PyTorch 1.x:没有前向 IR
PyTorch 不会像 Theano 或 TensorFlow 那样将前向传播转换为 IR。反之,PyTorch 使用 Python 解释器来运行前向传播。这样做的弊端在于会在运行期间生成表示后向传播的 IR,我们称之为 Eager 模式(动态图模式)。
TensorFlow 2.x: 梯度带
TensorFlow 2.x 增加了一个像 PyTorch API 的 Eager 模式 API。此 API 追踪前向传播如何运行名为梯度带(GradientTape)的 IR 。TensorFlow 2.x 可以从这个跟踪中找出后向传播。
JAX
JAX 不会向用户公开诸如梯度带等方面的低级别细节。简单说来,JAX 的思维方式为:将输入和输出都用 Python 函数来表示。
对于想要自己编写的函数转换的高级用户,他们可以调用`make_jaxpr`等低级 API 来访问 IR,称为 JAXPR。
FuncTorch
FuncTorch 和 JAX 类似,都是基于 PyTorch 的函数转换。
JAX 的`make_jaxpr`类似于 functorch 的`make_fx`。
TensorFlow 2.x、JAX 和 functorch 都为前向传递构建了一个 IR,但 PyTorch Eager 模式没有。IR 不仅可用于自动求导,还可用于其他类型的函数转换。在下列例子中,`functorch.compile.aot_function 调用了回调函数`print_compile_fn`两次,分别用于前向和后向传播。
2、高阶导数
PyTorch
TensorFlow 2.x
JAX
3、动态控制流
动态控制流(dynamic control flows)有两个层级:在 CPU 上运行的粗粒度级别和在 GPU /TPU 上运行的细粒度级别。本部分主要介绍在 CPU 上运行的粗粒度级别的动态控制流。下面我们将用(if/else)条件语句作为例子检验深度学习工具。
TensorFlow 1.x
在 TensorFlow 1.x 中,我们需要将条件语句显式构建到 IR 中。此时条件语句是一个特殊的运算符 `tf.cond`。
TensorFlow 2.x
TensorFlow 2.x 支持使用 `tf.cond` 和 `tf.while_loop` 显式构建控制流。此外,实验项目 google/tangent 中有 AutoGraph 功能,它可以将 Python 控制流转换为`tf.cond`或`tf.while_loop`。此功能利用了 Python 解释器支持的函数和函数源代码。例如下面的 g 函数调用了 Python 的标准库将源代码解析为 AST,然后调用 SSA 表单来理解控制流。
JAX
由于部分 Python 语法很复杂,所以通过解析源代码来理解控制流就显得很困难,这就导致 AutoGraph 经常出错。但如果这种方法很简单,那么 Python 开发者社区也不会在构建 Python 编译器时失败这么多次了。正是由于有这种挑战的存在,必须要明确地将控制流构建到 IR 中。为此,JAX 提供了 `jax.lax.cond` 和 `jax.lax.for_loop`函数。
考虑到这一点,你可能会觉得我们可以使用递归算法。但是下面用于计算阶乘的递归无法用 JAX 跟踪。
可能你还想调用 factorial 来计算 3!=6。但这会让递归深度超过最大值,因为递归不仅依赖于条件,还依赖于函数定义和调用。
PyTorch
PyTorch 最初是 Python-native。正如前文所说,由于多功能调度机制,grad 和 vamp 的函数转换都是即时的。值得注意的是:
相比 Theano 和 TensorFlow 构建 IR 后的函数转换,即时函数转换效率更高。
在进行`grad`和`vmap` 时,JAX 也是即时函数转换。然而像`pamp`和`pjit`等更复杂的函数转换需要对整个计算过程进行概述,在这个过程中 IR 是必不可少的。
由于 IR 在`pmap` 和 `pjit`中的必要性,PyTorch 社区最近添加了`torch.cond`pytorch/pytorch#83154
4、分布式计算
根据执行代码或 IR 的不同方式,在使用 Python 解释器或 runtime 时,有两种分布式计算方法。
Python-Native
Theano 和 PyTorch 采用了 Python-native 分布式计算方式。这种分布式训练工作包含多个 Python 解释器进程。这导致出现了以下结果。
打包和运行(Pack and run)。由于这些 Python 进程在不同的 host 上运行,因此我们需要打包用户程序和依赖项,并将它们发送到这些 host 上去运行。一直以来 TorchX 负责了这个打包过程。它支持例如 Docker 和 torch.package 等各种打包格式,并且可以与各种集群管理器配合使用,如 Kubernetes 和 SLURM。
单程序多数据(SPMD)。由于将用户程序发送到各种 host 上要依赖于打包,与其他权重较轻的方式(如通过 RPC 发送代码)相比,这种方式不太灵活,因此,我们通常只发送一个程序。当所有这些进程运行同一程序时,这个作业就变成了单程序多数据(SPMD)作业。
Python-native SPMD
下面是一个简单的 SPMD PyTorch 程序,我们可以在相同或不同的 host 上使用进程运行这个程序。在这个过程中,我们只需要调用`all_gather`。真正的分布式训练程序会调用更高级别的 API,例如`torch.nn.parallel.DistributedDataParallel` 和 `torchrec.DistributedModelParallel`, 然后再调用低级 API,例如 `all_gather` 和 `all_reduce`。
Python-native Non-SPMD
PyTorch 不仅限于 SPMD 式的分布式训练。它还通过`torch.distributed.pipeline.sync.Pipe`和`PiPPy project`提供流水并行,其中流水并行的各个阶段在不同的设备上运行不同的程序。这些阶段常通过 `torch.rpc` 包来沟通。
分布式运行时机制
分布式 TensorFlow 作业由运行 TensorFlow runtime 程序的进程组成,而不是由 Python 解释器组成。此分布式运行时作业执行 TensorFlow graph (IR),它是由执行用户程序的 Python 解释器生成。
用户程序可以使用低级 API(如 `tf.device`)去指定作业要运行什么操作、在哪台设备和主机上运行等等。因为 API 有 runtime,所以可以做到这一点。
与 PyTorch 一样,TensorFlow 也为分布式训练提供了高级 API `tf.distributed.strategy`,Keras 和 DTensor。
分布式运行时极大地方便了训练服务的维护,因为我们不再将用户程序打包到集群上运行。相反,我们打包运行时程序,因为相比用户程序,运行时程序更加统一。
混合理念
JAX 支持 Python-native 和分布式运行时。
JAX 提供例如`vmap`、`pmap` 和 `pjit`的函数转换,这可以将 Python 函数转换为分布式程序。
(本文经授权后由 OneFlow 社区编译,译文转载请联系获得授权。原文:https://quip.com/Y8qtAyV4EXRg)
其他人都在看
欢迎 Star、试用 OneFlow 最新版本:https://github.com/Oneflow-Inc/oneflow/
版权声明: 本文为 InfoQ 作者【OneFlow】的原创文章。
原文链接:【http://xie.infoq.cn/article/abac1a7232b26a3a2ae40fa4e】。文章转载请联系作者。
评论