1 回顾
参照文章:Torch.compile()流程解析——3. AOTAutograd
在上一篇Ascend的aclgraph(三)TorchDynamo的介绍中,解析了torch.compile()
是如何捕获计算图并保存为 GraphModule 的,但在这个过程中只是对整个 Python 字节码进行了模拟执行、解析并构建 FX Graph,相当于只是初步构建了前向计算图,没有捕获训练场景下的反向计算图。在 PyTorch 中反向计算图的捕获是放在 backend compiler 里面实现了,以torch.compile
的默认 backend compiler——inductor 为例,在其实现函数compile_fx
中,涵盖了AOTAutograd
(捕获 fw-bw joint graph)、PrimTorch(lowering op)
和TorchInductor(图优化、Triton)
。下面开始解析 backend compiler 的默认函数——inductor 的函数实现,并梳理出剩下三个组件的原理。
2 AOTAutograd
首先介绍 AOTAutograd,AOTAutograd 是 PyTorch 引入的一种自动求导机制,旨在在模型执行之前预先生成梯度计算的代码。这种方法通过静态分析模型的前向计算图,提前生成反向传播所需的梯度计算逻辑,从而减少运行时的开销,提升训练效率。有了 AOTAutograd,开发者可以做以下事情:
获取反向传播计算图、甚至是正向传播和反向传播联合的计算图;
用不同的后端编译器分别编译正向传播和反向传播计算图;
针对训练 (training) 做正向传播、反向传播联合优化,比如通过在反向传播中重算 (recompute) 来减少正向传播为反向传播保留的 tensor,从而削减内存需求;
总的来说,AOTAutograd 的工作流程如下:
基于torch_dispatch
机制 trace 正向反向传播,生成联合计算图(joint graph)。
通过 decompositions 进一步拆解,将 FX Graph 进一步转换为更低层次的中间表示,即PrimTorch
。
通过partition_fn
将joint-graph
切分成正反向计算图。
调用fw_compiler
和bw_compiler
对正向、反向计算图分别进行编译,并整合成一个torch.autograd.Function
。
看到这里,是不是想起来了Ascend的aclgraph(二)_npu_backend中还有些什么秘密?中提到的 3 个 compile 执行函数。
return fw_compiler, inference_compiler, joint_compiler
复制代码
3 torch dispatch
AOTAutograd 是基于torch_dispatch
机制在算子下发执行前获得真正实际执行的 op,并构建对应的 Proxy,即 PyTorch 反向传播的计算图是在执行正向过程中动态创建的,这也意味着执行完整的前向过程才能构建出对应的 FX Graph,从而在函数正式执行前拿到正反向计算图,实现 AOTAutograd,而这一过程也是依赖于前面 TorchDynamo 捕获的 FX Graph 这一 IR 表示。
在正式解析 AOTAutograd 之前先了解一下torch_dispatch
机制。PyTorch 的核心是一个 dispatcher,功能是根据输入 tensor 的属性把算子 dispatch 到具体的 kernel 上,如根据 tensor 的 device 属性决定是调用 CUDA kernel 还是 CPU 实现,从而综合各项属性算出一个 dispatch key 决定调用哪个 kernel。一个算子在 PyTorch 中往往要经过多次 dispatch,而__torch_dispatch__
给开发者提供了在算子最终 dispatch 前获取对应的算子和输入的接口。
后续 AOTAutograd 实现的代码逻辑如下,感兴趣的小伙伴也可以看看后面的代码解析部分
4 Joint Graph
在Ascend的aclgraph(三)TorchDynamo中提及通过在 TorchDynamo 构建 FX Graph 后会调用call_user_compiler
()调用 backend compiler 对计算图进行编译,torch.compile()的默认编译函数实现 inductor 的入口函数是compile_fx()
,调用接口信息如上图。
分析compile_fx()
的函数调用栈,其核心实现是aot_dispatch_autograd()
函数,其主要流程如下:
调用aot_dispatch_autograd_graph
()生成前反向 joint graph。
调用partition_fn
进行切分,最后返回包含前、反向计算图的 torch.autograd.Function。
首先介绍aot_dispatch_autograd_graph
()函数生成 joint graph 的过程:
通过create_joint
()函数将正反向计算封装成函数,create_joint
()根据前向计算结果分析出需要计算梯度的参数以及对应的 tangents(梯度权值),然后通过torch.autograd.grad
进行反向求导,并将正反向过程封装在函数中返回,作为joint_fn_to_trace
。
由_create_graph
()对joint_fn_to_trace
函数进行跟踪,核心是调用make_fx
()函数在算子 dispatch 前拿到实际真正执行的 op 并创建 Proxy 添加到 FX Graph 中。
在make_fx
()函数中是通过_MakefxTracer.trace
()函数对整个函数计算过程进行跟踪并生成 GraphModule,GraphModule 中包含正反向计算对应的计算图。需要注意的是这里的正反向计算是 TorchDynamo graph break 对应的子图,即每个子图都会调用一次make_fx
生成 joint graph。捕获过程主要包括两个核心操作:
对输入输出的封装:在dispatch_trace()
->Tracer.trace()
中会为函数参数、局部变量以及输出生成对应的 Proxy。其中通过create_args_for_root
()->create_proxy
()为所有变量(函数参数和局部变量)创建类型为 placeholder 的 Proxy,在 create_proxy()中会同步创建 Node 并将其加入到 FX Graph 中,并用 Proxy 封装一下 Node。通过 create_node()为输出创建类型为 output 的 Node,并将其加入到 FX Graph 中)。
op dispatch 的捕获和封装: 在with decompose
()上下文管理中通过 self.proxy_mode 指定了ProxyTorchDispatchMode
(用于拦截和自定义张量操作的分发过程,Dispatch Mode 机制允许开发者在张量操作(如加法、矩阵乘法等)被执行时,插入自定义逻辑,以实现诸如调试、性能监控、自定义后端支持等功能,而不需要修改 Python 的核心代码)。通过重写torch_dispatch
函数指定op dispatch
过程中插入的操作,在ProxyTorchDispatchMode
中是对 op 的 decompose(拆解到 PrimTorch 规定的集合中),同时为 op 创建类型为 call_function()的 Proxy。
# aot_dispatch_autograd_graph()函数实现
# ps:这里只展示核心函数调用
def aot_dispatch_autograd_graph(
flat_fn,
flat_args: List[Any],
aot_config: AOTConfig,
*,
fw_metadata: ViewAndMutationMeta,
) -> Tuple[torch.fx.GraphModule, Tuple[List[Any], List[Any]], Optional[SubclassMeta]]:
# traced_tangents corresponds to the set of outputs in the traced forward that should get grad_outputs in the traced backward.
# It includes outputs of the original forward, *and* any updated inputs due to input mutations.
# However, it does *not* include any outputs that are aliases of inputs or intermediates, or any metadata-only input mutations.
joint_inputs = (flat_args, fw_metadata.traced_tangents)
joint_fn_to_trace = create_joint(fn_prepared_for_autograd, aot_config=aot_config) # 生成正反向计算,封装成函数
fx_g = _create_graph(joint_fn_to_trace, updated_joint_inputs, aot_config=aot_config) # 通过make_fx()函数跟踪joint_fn_to_trace的计算过程生成joint_graph,以torch.fx.GraphModule格式返回
# _create_graph()的核心实现实现
# path:torch/fx/experimental/proxy_tensor.py::class _MakefxTracer
# ps:只展示核心代码实现
def _trace_inner(self, f, *args):
phs = pytree.tree_map(lambda _: fx.PH, args) # type: ignore[attr-defined]
args = _wrap_fake(args)
func = _wrap_func(f, phs)
# We disable the autocast cache as the autocast cache causes type conversions on parameters to
# check a cache, which introduces untracked tensors into the graph
#
# We also disable tracing by any other tensor proxy-based tracers except the current. The
# purpose of `make_fx` is to produce graphmodules as a side effect; its internal execution is
# thus irrelevant to any external functional trace.
with decompose(self.decomposition_table), self.fake_tensor_mode, self.python_dispatcher_mode, self.proxy_function_mode, \
self.proxy_mode.sym_mode, self.torch_fn_metadata_mode, \
self.proxy_mode, disable_autocast_cache(), _set_make_fx_tracer(self): # 设置op decompose,涉及PrimTorch
t = dispatch_trace( # 跟踪op dispatch并将其添加到FX Graph中,最后生成GraphModule
wrap_key(func, args, self.fx_tracer, self.pre_dispatch),
tracer=self.fx_tracer,
concrete_args=tuple(phs)
)
return t
# __torch_dispatch__执行过程
r = maybe_handle_decomp(proxy_mode, func, args, kwargs) # 基于CURRENT_DECOMPOSITION_TABLE查找op对应的函数实现并返回
if r is not NotImplemented:
return r
# 不是ATen op则进一步拆解算子
# For pre-autograd tracing, we do not want to run CompositeImplicit decomps.
if not pre_dispatch and func not in [
torch.ops.aten.size.default,
torch.ops.aten.stride.default,
torch.ops.aten.storage_offset.default,
]:
with proxy_mode:
r = func.decompose(*args, **kwargs)
if r is not NotImplemented:
return r
# 对中间函数调用创建类型为call_function的Proxy
proxy_args, proxy_kwargs = pytree.tree_unflatten(proxy_flat_args_kwargs, spec)
proxy_out = proxy_mode.tracer.create_proxy(
"call_function",
func,
proxy_args,
proxy_kwargs,
name=proxy_mode.tracer.graph._target_to_str(func.overloadpacket.__name__),
)
out = func(*args, **kwargs) # 以FakeTensor作为输入运行函数拿到对应的输出
track_tensor_tree(out, proxy_out, constant=constant, tracer=tracer) # 将结果Tensor绑定到对应Proxy中
return out
复制代码
通过上面的torch_dispatch
和_MakefxTracer.trace
()跟踪,从而在整个joint_fn_to_trace
执行完毕后将所有操作都记录到 FX Graph 中,构建出正反向 joint graph。并在_MakefxTracer.trace
()中通过fx._lazy_graph_module._make_graph_module(tracer.root, graph, name)
生成 GraphModule 并一路返回到aot_dispatch_autograd_graph
(),对 joint graph 通过eliminate_dead_code
()进行冗余代码消除和 recompile()生成对应 python 代码,并返回到aot_dispatch_autograd
()进行后续的切分。
以经典例子my_func
()函数为例,生成的正反向 joint graph 如下,所有局部变量对应 placeholder,输出对应 output,中间计算全部对应 call_function,且每个子图对应一个 joint graph。
# 前反向jont-graph示例
opcode name target args kwargs
------------- ------ ---------------- -------------- --------
placeholder arg0_1 arg0_1 () {}
placeholder arg1_1 arg1_1 () {}
call_function sum_1 aten.sum.default (arg0_1,) {}
call_function sum_2 aten.sum.default (arg1_1,) {}
call_function gt aten.gt.Tensor (sum_1, sum_2) {}
output output output ((gt,),) {}
opcode name target args kwargs
------------- ---------- ---------------- ----------------- --------
placeholder primals_1 primals_1 () {}
placeholder tangents_1 tangents_1 () {}
call_function cos aten.cos.default (primals_1,) {}
call_function cos_1 aten.cos.default (cos,) {}
call_function sin aten.sin.default (cos,) {}
call_function neg aten.neg.default (sin,) {}
call_function mul aten.mul.Tensor (tangents_1, neg) {}
call_function sin_1 aten.sin.default (primals_1,) {}
call_function neg_1 aten.neg.default (sin_1,) {}
call_function mul_1 aten.mul.Tensor (mul, neg_1) {}
output output output ([cos_1, mul_1],) {}
复制代码
5 partition 拆分计算图
前面已经介绍了通过aot_dispatch_autograd_graph
()函数获得包含 joint graph 的 GraphModule,回到aot_dispatch_autograd
()函数,通过aot_config.partition_fn
()进行切分,这里目前内置了两种 partition_fn:
default_partition:模拟了 PyTorch 的默认行为,找出从 forward 的输入到输出的所有算子输出,剩余部分都视为 backward 部分,从而分割出正反向 graph,forward 的所有中间结果都保存用于 backward。
min_cut_rematerialization_partition:通过在 backward 中引入重计算,减少 forward 给 backward 保留的 tensor 以节省显存占用,这种重计算的思路与 gradient/activation checkpointing 一致。除了 backward 的输入 Tensor(即直接参与 tangents 计算的 Tensors,也叫 tangent’s closure)必须要保留外,其余的 Tensors 有多种去留方案,但如何选择 forward 保留给 backward 的 Tensors 做到计算和显存之间的 tradeoff,这里采用的是求解最大流最小割(max-flow/min-cut) 问题的方式,流程如下:
在源节点 (source,虚拟添加) 和 primals(forward 输入 Tensors)之间各添加一条边,在所有的 tangent’s closure(backward 输入 Tensors)和目标节点 (sink,虚拟添加)之间各添加一条边,它们组成了一张从 source 到 sink 的有向图,边上的权重是 tensor size 即代表显存占用大小;
需要找到一个合适的切分方法,把这个有向图分成两部分,使得 source 子图到 target 子图之间边上的权重之和最小,这是一个最小割问题;
最小割问题的对等问题是最大流问题,已经有标准的解法,因此直接在该有向图上运行最大流算法即可得到最佳划分方法,将最小割集合上的 Tensors 作为 forward 保留的 Tensors 集合,剩余 Tensors 删除并在 backward 过程中重计算。
inductor 编译函数在compile_fx
()函数中默认定义了分割算法为min_cut_rematerialization_partition
实现,以my_func
()函数为例,对于其中 if/else 子图的切分,不保留 cos 计算的中间结果,而在 backward 过程中重计算(这里的例子比较简单无法显示实际效果。。。)。
# joint-graph示例
============original joint graph
opcode name target args kwargs
------------- ---------- ---------------- ----------------- --------
placeholder primals_1 primals_1 () {}
placeholder tangents_1 tangents_1 () {}
call_function cos aten.cos.default (primals_1,) {}
call_function cos_1 aten.cos.default (cos,) {}
call_function sin aten.sin.default (cos,) {}
call_function neg aten.neg.default (sin,) {}
call_function mul aten.mul.Tensor (tangents_1, neg) {}
call_function sin_1 aten.sin.default (primals_1,) {}
call_function neg_1 aten.neg.default (sin_1,) {}
call_function mul_1 aten.mul.Tensor (mul, neg_1) {}
output output output ([cos_1, mul_1],) {}
======================forward graph
opcode name target args kwargs
------------- --------- ---------------- --------------------- --------
placeholder primals_1 primals_1 () {}
call_function cos aten.cos.default (primals_1,) {}
call_function cos_1 aten.cos.default (cos,) {}
output output output ([cos_1, primals_1],) {}
======================backward graph
opcode name target args kwargs
------------- ---------- ---------------- ----------------- --------
placeholder primals_1 primals_1 () {}
placeholder tangents_1 tangents_1 () {}
call_function cos aten.cos.default (primals_1,) {}
call_function sin aten.sin.default (cos,) {}
call_function neg aten.neg.default (sin,) {}
call_function mul aten.mul.Tensor (tangents_1, neg) {}
call_function sin_1 aten.sin.default (primals_1,) {}
call_function neg_1 aten.neg.default (sin_1,) {}
call_function mul_1 aten.mul.Tensor (mul, neg_1) {}
output output output ([mul_1],) {}
======================end
复制代码
到此,完成了 TorchDynamo 的 FX Graph 捕获,并通过 AOTAutograd 实现正反向计算图的跟踪和切分,同时在 op dispatch 过程中通过 decompose 实现拆解为 PrimTorch 的规范算子集合。
6 总结
在前期分析的 torchair 的源代码中,FX Graph 是 complie 相关函数输入的第一个参数,joint graph,default_partition,call_function 都是常见的对象,看完本篇,相应大家对这些概念有个基础的了解。下一篇章,主要介绍 Inductor 相关的优化部分。
评论