Ascend 的 aclgraph(三)TorchDynamo

1 基本介绍
参考:【编译系列】Torch.compile()流程解析——2. TorchDynamo针对 TorchDynamo 的介绍,该已经讲的比较仔细,本篇文章,主要是基于原文并加上自己的理解介绍。
在上一篇Ascend的aclgraph(2)_npu_backend中还有些什么秘密?,解释了 torch.compile 出现的背景并初步了解了其使用和基础组。
先回顾下 torch.compile 主要包含四个基础组件:
TorchDynamo:从 python bytecode 中解析构建计算图,是一个动态的、Python 级别的编译器,旨在捕捉 PyTorch 模型的动态执行路径,将其转换为优化代码,实现 bytecode-to-bytecode 编译。
AOTAutograd:是 PyTorch 引入的一种自动求导机制,旨在在模型执行之前预先生成梯度计算的代码。这种方法通过静态分析模型的前向计算图,提前生成反向传播所需的梯度计算逻辑,从而减少运行时的开销,提升训练效率。
PrimTorch:将 2000+ Pytorch op 规范化为约 250 个原始 op 封闭集合,是 PyTorch 的一个中间表示(Intermediate Representation, IR)层,负责将高层的 PyTorch 操作转换为更低层次的、适合进一步优化和编译的基础操作。它通过简化和标准化操作,提高编译器和后端优化器处理计算图的效率。
TorchInductor:深度学习编译器,可为多种加速器和后端生成代码,生成 OpenAI Triton(Nvidia/AMD GPU)和 OpenMP/C++(CPU)代码,它利用多种优化技术,包括内存优化、并行化和低层次的代码生成,以最大化计算性能。
本章我们将解释四个基础组件中的TorchDynamo
。
2 TorchDynamo
PyTorch Dynamo 是 torch.compile 编译器的前端,能够将 eage 代码编译为 FX Graph,进而提交给 lowering 的编译器(Inductor 等)进行编译,最终生成优化过后的底层机器代码,达到加速的效果。TorchDynamo 的基本工作流程是基于 PEP 523(Python Enhancement Proposal)在函数执行前拿到 Python 字节码,通过解析并模拟执行每条 Python 字节码逐步创建 FX Graph,对 if/else、loop 或不支持的操作,会触发 graph break 生成 sub-graph。主要具备一下功能:
在代码执行前从 PVM 层拦截用户代码(基于 PEP 523)
模拟用户代码直接, 理解代码逻辑
最后将其他编译组件优化后的代码交还给 PVM。
例如对于上面的 my_func()函数会生成三个 subgraph。
其中opcode
有六种:
placeholder 对应输入,
call_method/call_function/call_module 对应函数/方法/模型调用,
utput 对应输出,
target 是函数调用,
name 是 op 结果名称,
args 和 kwargs 是参数。
2.1 call_method/call_function
在 PyTorch 的 torch.compile 中,特别是在使用 torch.fx 进行模型追踪和转换时,call_method 和 call_function 是两种不同的操作类型,它们用于表示计算图中的不同节点。理解这两者的区别对于有效地利用 torch.fx 进行模型优化和变换非常重要。
call_method 定义:call_method 用来调用某个对象的方法。这意味着你有一个特定的对象实例,并且想要调用该实例上的一个方法。使用场景:当你需要对一个对象(如张量、模块等)调用其定义的方法时使用。例如,调用一个张量的 .size() 或 .reshape() 方法。
call_function 定义:call_function 用于直接调用 Python 函数。这与 call_method 不同,因为它不依赖于任何特定的对象实例。使用场景:当你想调用一个独立的函数,比如内置的数学运算函数或其他不依赖于特定对象的方法时使用。例如,torch.add 或者 torch.relu。
主要区别
依赖性:call_method 需要一个对象实例作为调用上下文,而 call_function 则不需要,它可以直接调用全局或局部定义的函数。
灵活性:由于 call_function 不依赖于特定对象,因此它可以提供更大的灵活性,尤其是在进行函数式编程或编写与特定类无关的通用代码时。
表达方式:两者在计算图中表示的方式不同,这也影响了如何对它们进行变换和优化。理解这些差异有助于更好地操控计算图,实现更高效的模型优化。
总之,call_method 和 call_function 在 torch.fx 中分别用于表示对对象方法的调用和对自由函数的调用,正确地使用它们可以提高模型追踪和转换的准确性和效率。
2.2 示例
TorchDynamo 工作示意图

优点:
动态优化:能够处理包含动态控制流的模型,如循环和条件语句,适用于动态计算图。
自动化:用户无需手动编写转换代码,可以自动识别和优化执行路径。
灵活性:支持多种后端优化器,如 TorchInductor,提供多样化的性能提升方案。
缺点:
工具:作为较新的优化工具,可能在某些极端场景下还不够稳定或全面支持所有 PyTorch 功能,尚未提供序列化/反序列化 API。
依赖后端:最终性能提升依赖于所使用的后端优化器,某些后端可能在特定硬件或模型上表现不佳。 和其他静态图构建方式相比,TorchDynamo 更为灵活且支持更多复杂的操作,而不需要用户做大量的代码修改适配。
torch 的几种成图方式对比

3 CPython 代码的执行过程 & PEP 523:
在正式进入 TorchDynamo 工作过程之前先了解 CPython 的工作流程。首先介绍 Cpython 中两个重要的对象——PyCodeObject
和PyFrameObject
。
PyCodeObject
保存二进制字节码、常量表、变量名表等静态信息;PyFrameObject
是一个用于表示执行环境的对象,每次函数调用时,Python 都会创建一个新的PyFrameObject
,其中包含了该函数的 PyCodeObject,以及一些其他运行需要的信息,如存放局部变量的内存空间和evaluation stake
(函数调用栈)等。
CPython 在执行 Python 函数前会将 Python 代码编译为字节码,由 Python 虚拟机(PVM)中_PyEval_EvalFrameDefault()函数逐条执行编译好的字节码,而 PEP 523 提供了一个 API 接口让用户在 PVM 执行字节码之前获得待执行的字节码,从而可以对字节码进行优化修改实现即时编译(JIT Compiler)的效果。 TorchDynamo 正是基于 PEP 523 把 TorchDynamo 的编译逻辑引入到 Python 代码的解释执行过程中。
通过 CPython 提供的_PyInterpreterState_SetEvalFrameFunc
()函数把 CPython 中用于执行字节码的默认函数给替换为custom_eval_frame_shim
()。 在执行用户想要编译的函数时便会进入_custom_eval_frame_shim
().
在_custom_eval_frame
函数中,会先通过lookup
函数检查 cache 中是否有已编译代码,若存在则直接调用eval_custom_code
函数执行,从而避免重复编译相同函数。若 cache 未命中,则通过call_callback
调用回调函数进行编译,并通过set_extra
()将编译结果保存在 PyFrameObject 中,最后调用eval_custom_code
继续进行执行。而这里的回调函数也即前面在torch._dynamo.optimize
传入的回调函数:convert_frame.convert_frame(backend, hooks=hooks)
(包含编译入口 compile_fn)。
因此 torch.compile 只有在第一次正式执行代码前才会进行编译,这也导致测试编译代码的时间时需要考虑数据预热。
到此,解释了torch.compile
是如何在 Python 代码执行过程中引入TorchDynamo
的,接下来回到torch._dynamo.optimize
解析是如何一步步从字节码构建 FX Graph。
4 TorchDynamo 模拟执行 & FX Graph 构建
回到torch._dynamo.optimize
设置的回调函数convert_frame.convert_frame(backend, hooks=hooks)
,其核心函数是_compile()
,用于负责对字节码进行编译。先看代码。
根据如上的代码逻辑,抽象出如下的流程图。

在_compile()
中会对缓存大小进行判断,如果缓存大小超过配置会有警告信息,默认值为 64,含义是对于同一个 Python 函数,如果函数的输入张量信息组合变化超过 64 种,TorchDynamo 则不会继续编译用户指定的函数。 在_compile 中通过transform_code_object(code, transform)
对用户代码进行优化转换,其中code
是PyCodeObject
类型,即待编译优化的字节码,transformer
是转换函数。在transform_code_object
函数中,cleaned_instructions
()用来预处理字节码指令,通过 Python 标准库 dis.get_instructions(code)
获取字节码指令,对字节码进行清洗(例如对跳转指令做标准化处理),并转为结构化数据表示 Instruction,方便后续的优化。
例如对我们的样例函数 my_func 进行 cleaned_instructions()后的结果如下。
清理后的字节码通过transformations(instructions, code_options)
函数进行处理,核心 transform()函数实现如下。
首先实例化了InstructionTranslator
对象,在InstructionTranslator
中有一个 OutputGraph 的实例,用于保存InstructionTranslator
编译后的输出,以 torch.fx.Graph
表示。
其中transform()
(也即 TorchDynamo)构建 FX Graph 的核心模块有两个:
InstructionTranslator
的初始化过程,负责对变量构建对应的 Proxy,对应 FX Graph 中的placeholder
部分; -InstructionTranslator.run()
负责模拟运行字节码并构建对应的 Node 添加到 FX Graph 中;
在InstructionTranslator
的初始化过程中,通过 PyCodeObject 对象的 co_varnames 字段获取待编译函数中的变量名,并为每一个变量创建一个LazyVariableTracker
,作为symbolic_locals
。其中LazyVariableTracker
是一种推迟创建给定底层值的VariableTracker
,直到访问该值才创建用于节省空间资源,而VariableTracker
被用于记录每个 Python 变量对应的类型信息用于构建静态图。
其中LazyVariableTracker
通过 VariableBuilder
来生成实际对象。以torch.Tensor
为例,VariableBuilder 的创建逻辑如下:
首先通过
create_graph_input
()在 FX Graph 中创建了类型为 placeholder 的 FX Proxy(FX Proxy 是 FX symbolic tracing 中的 symbol,placeholder 对应变量,即前面的 opcode)。install_guards
()函数创建了类型为 GuardBuilder.TYPE_MATCH 的 Guard 对象,Guard 在 TorchDynamo 中负责检测被编译函数所引用的外部数据信息是否发生变化,如果没有发生变化则可以复用之前编译好的函数,否则需要重新编译该函数。TYPE_MATCH 主要判断两者的数据类型是否一致,TENSOR_MATCH 主要对输入 Tensor 的 shape、stride 等信息进行检查是否发生改变。wrap_fx_proxy
()为刚刚创建的 Proxy 建立实际的 VariableTracker,核心逻辑实现在wrap_fx_proxy_cls
()函数:在
wrap_fx_proxy_cls
()函数中首先通过wrap_to_fake_tensor_and_record
()函数为运行时获得的 torch.Tensor 创建 FakeTensor(默认情况下,TorchDynamo 使用 FakeTensor 捕获计算图而不是真实的 torch.Tensor,FakeTensor 具有和真实 torch.Tensor 相同的张量信息,但没有实际的数据和张量内存分配);通过
specialize
()函数特化张量信息(包括 dtype、device 等),在 static shape 模式下还会特化 size、stride、is_contiguous 信息,而在 dynamic shape 模式下则不会特化这部分信息。最后通过
target_cls
创建对应的VariableTracker
对象,例如这里的是 torch.Tensor,则创建的是 TensorVariable(VariableTracker 的子类),用于记录 Pytorch 中的 torch.Tensor 类型数据的相关信息。
因此,在InstructionTranslator
对象初始化创建VariableTracker
的过程中,TorchDynamo 完成了:
在 FX Graph 中创建 FXProxy
添加 Guard 和 FakeTensor 相关操作并初始化 VariableTracker,由于并不是所有的局部变量都会被当前 frame 用到,为了节省资源开销这里采用 LazyVariableTracker,只有到实际使用的时候才会进行创建。
到此完成输入对应的 VariableTracker 创建,会在后续一直带着 Guard、FakeTensor 等信息用于跟踪 Tensor 的后续操作。
完成InstructionTranslator
对象的初始化,回到InstructionTranslator.run()
函数,由于InstructionTranslator
继承于class InstructionTranslatorBase
,所以这里实际调用的是InstructionTranslatorBase.run()
函数。InstructionTranslatorBase 的本质是一个 Python 虚拟机的模拟器,在循环中对字节码逐条解析模拟执行加粗样式,对其核心函数step()
进行分析,首先基于 instruction_pointer 获取待执行的字节码指令,通过 dispatch_table 映射表获取到每个 op 对应的函数调用并进入函数解析当前字节码指令,当遇到循环、if/else 等跳转相关的字节码指令时会触发 compile_subgraph()函数进入子图编译相关操作。代码如下:
对于每条字节码的模拟执行和解析,以下面的CALL_FUNCTION
函数调用为例,会先根据argval
弹出对应的函数参数,并进一步调用TensorVariable.call_method()
函数。在call_method()
函数中,proxy_args_kwargs()
函数从symbolic_locals
中获取相应的函数参数 Proxy,然后调用create_proxy()
创建新的 Proxy,类型是call_method
,并有对应的 method 名(如 my_func 中的 x.sum(),对应 TorchDynamo 中的 target 项)和参数。最后通过wrap_fx_proxy
()(和前面创建局部变量一样)创建新的 TensorVariable 来保存结果,中间收集到的 Guard 信息也附加了上去,最后在 call_function()函数中将结果压栈。到此完成当前字节码的模拟执行并在此过程中将对应的 Proxy 添加到 FX Graph 中。
因此,TorchDynamo 在字节码的分析过程中并没有真正地执行指令,而是以符号分析的方式从字节码中提取相应的符号和函数,创建相应的 Proxy 并添加到 FX Graph 中,通过指令逐条模拟执行(解析)不断构建 FX Graph,直到触发 compile_subgraph()。
5 子图编译
在前面的分析中,TorchDynamo 都是逐条执行指令然后不断地构建 FX Graph,但当遇到例如 jump 字节码指令(对应 if/else、循环等)时,会触发compile_subgraph()
函数,因为在 TorchDynamo 中是以子图为单位进行编译的(除了设置full_graph=True
),在compile_subgraph
()中完成 FX Graph 一个完整子图的构建,并调用 backend compiler 对该子图进行编译。 分析compile_subgraph
()的核心函数compile_and_call_fx_graph
()实现:
调用
create_node
创建类型为 output 的 Proxy(对应输出返回值),到此一张完整的 FX Graph 构建完毕。基于完整的 FX Graph 创建对应的 GraphModule 作为编译函数的入参,并通过
call_user_compiler
()函数调用 backend compiler 对 GraphModule 进行编译(在这里开始进入 inductor 编译函数)。通过
PyCodegen.get_instructions()
函数获得编译后函数对应的 Instructions,到此完成整个子图的编译部分。
最后调用add_output_instructions
()函数将 should_exit 属性置为 True,这意味回到InstructionTranslatorBase.run()
循环中会退出。 一路回退到_compile(),在transform_code_object()
中会调用clean_and_assemble_instructions()
将 Instruction 汇编为 Python 可执行的字节码。
6 Guard 生成
对于静态图的生成,特别是 static shape,TorchDynamo 中还有一个重要的组成部分——Guard。随着上述提及的函数调用栈一路回退到_compile(),此时已经完成了 FX Graph 的构建、调用 backend compiler 进行了编译并对编译后代码生成字节码,TorchDynamo 最后需要为之前构建 FX Graph 过程中收集的 Guard 生成检测代码(Python 代码),从而在后续执行代码时检测代码是否已编译过。 TorchDynamo 通过CheckFunctionManager
的compile_check_fn()
函数为 Guard 生成可执行 Python 代码,为了降低运行时检测输入是否发生变化的函数开销,TorchDynamo 把 Guard 检测功能实现在了 C++中。(具体实现可以查阅 /usr/local/lib/python3.9/dist-packages/torch/_dynamo/guards.py ) 回到_compile()
函数,check_fn
即生成的 Guard 检查函数,GuardedCode
保存编译好的子图out_code
和check_fn
。
到此完成了一个完整的子图的全部构建和编译工作,最终回到最开始的_custom_eval_frame
()函数,对编译完的代码调用eval_custom_code
(),送入 CPython 默认的执行函数入口_PyEval_EvalFrameDefault
进行执行,完成编译后子图的执行(和 Pytorch eager 模式执行一样)。
7 后续函数的执行过程
回到_custom_eval_frame
()函数,此时拿到了编译好的 GuardedCode,create_cache_entry()和 set_extra()往当前用户函数的 frame->f_code 里写入了一跳 CacheEntry,记录了 check_fn 和编译好的 code。eval_custom_code
()创建了一个新的 Python Frame,并运行编译好的函数。 eval_custom_code
() 中直接调用了eval_frame_default
()来执行上面的字节码,所以此处不会再次触发 TorchDynamo 定制的 Frame Evaluation 函数custom_eval_frame_shim
()。执行完编译过的子图,程序返回到 Python 解释器,下一条字节码是 if/else 对应的跳转指令,会再次触发 TorchDynamo 设置的 Frame Evaluation 函数 custom_eval_frame_shim
()继续进行子图的捕获和编译。
评论