写点什么

Ascend 的 aclgraph(三)TorchDynamo

作者:zjun
  • 2025-05-19
    上海
  • 本文字数:16186 字

    阅读完需:约 53 分钟

Ascend的aclgraph(三)TorchDynamo

1 基本介绍

参考:【编译系列】Torch.compile()流程解析——2. TorchDynamo针对 TorchDynamo 的介绍,该已经讲的比较仔细,本篇文章,主要是基于原文并加上自己的理解介绍。


在上一篇Ascend的aclgraph(2)_npu_backend中还有些什么秘密?,解释了 torch.compile 出现的背景并初步了解了其使用和基础组。


先回顾下 torch.compile 主要包含四个基础组件:


  • TorchDynamo:从 python bytecode 中解析构建计算图,是一个动态的、Python 级别的编译器,旨在捕捉 PyTorch 模型的动态执行路径,将其转换为优化代码,实现 bytecode-to-bytecode 编译。

  • AOTAutograd:是 PyTorch 引入的一种自动求导机制,旨在在模型执行之前预先生成梯度计算的代码。这种方法通过静态分析模型的前向计算图,提前生成反向传播所需的梯度计算逻辑,从而减少运行时的开销,提升训练效率。

  • PrimTorch:将 2000+ Pytorch op 规范化为约 250 个原始 op 封闭集合,是 PyTorch 的一个中间表示(Intermediate Representation, IR)层,负责将高层的 PyTorch 操作转换为更低层次的、适合进一步优化和编译的基础操作。它通过简化和标准化操作,提高编译器和后端优化器处理计算图的效率。

  • TorchInductor:深度学习编译器,可为多种加速器和后端生成代码,生成 OpenAI Triton(Nvidia/AMD GPU)和 OpenMP/C++(CPU)代码,它利用多种优化技术,包括内存优化、并行化和低层次的代码生成,以最大化计算性能。


本章我们将解释四个基础组件中的TorchDynamo

2 TorchDynamo

PyTorch Dynamo 是 torch.compile 编译器的前端,能够将 eage 代码编译为 FX Graph,进而提交给 lowering 的编译器(Inductor 等)进行编译,最终生成优化过后的底层机器代码,达到加速的效果。TorchDynamo 的基本工作流程是基于 PEP 523(Python Enhancement Proposal)在函数执行前拿到 Python 字节码,通过解析并模拟执行每条 Python 字节码逐步创建 FX Graph,对 if/else、loop 或不支持的操作,会触发 graph break 生成 sub-graph。主要具备一下功能:


  1. 在代码执行前从 PVM 层拦截用户代码(基于 PEP 523)

  2. 模拟用户代码直接, 理解代码逻辑

  3. 最后将其他编译组件优化后的代码交还给 PVM。


例如对于上面的 my_func()函数会生成三个 subgraph。


def my_func(x, y):    if x.sum() > y.sum():        loss = torch.cos(torch.cos(x))    else:        loss = torch.cos(torch.cos(y))    return loss
复制代码


其中opcode有六种:


  1. placeholder 对应输入,

  2. call_method/call_function/call_module 对应函数/方法/模型调用,

  3. utput 对应输出,

  4. target 是函数调用,

  5. name 是 op 结果名称,

  6. args 和 kwargs 是参数。

2.1 call_method/call_function

在 PyTorch 的 torch.compile 中,特别是在使用 torch.fx 进行模型追踪和转换时,call_method 和 call_function 是两种不同的操作类型,它们用于表示计算图中的不同节点。理解这两者的区别对于有效地利用 torch.fx 进行模型优化和变换非常重要。


call_method 定义:call_method 用来调用某个对象的方法。这意味着你有一个特定的对象实例,并且想要调用该实例上的一个方法。使用场景:当你需要对一个对象(如张量、模块等)调用其定义的方法时使用。例如,调用一个张量的 .size() 或 .reshape() 方法。


示例:python浅色版本# 假设我们有一个张量 xx = torch.tensor([1, 2, 3])# 使用 call_method 调用 x 的 size 方法graph.call_method("size", args=(x,))
复制代码


call_function 定义:call_function 用于直接调用 Python 函数。这与 call_method 不同,因为它不依赖于任何特定的对象实例。使用场景:当你想调用一个独立的函数,比如内置的数学运算函数或其他不依赖于特定对象的方法时使用。例如,torch.add 或者 torch.relu。


示例:python浅色版本# 直接调用 torch.add 函数graph.call_function(torch.add, args=(x, y))
复制代码


主要区别


  • 依赖性:call_method 需要一个对象实例作为调用上下文,而 call_function 则不需要,它可以直接调用全局或局部定义的函数。

  • 灵活性:由于 call_function 不依赖于特定对象,因此它可以提供更大的灵活性,尤其是在进行函数式编程或编写与特定类无关的通用代码时。

  • 表达方式:两者在计算图中表示的方式不同,这也影响了如何对它们进行变换和优化。理解这些差异有助于更好地操控计算图,实现更高效的模型优化。


总之,call_method 和 call_function 在 torch.fx 中分别用于表示对对象方法的调用和对自由函数的调用,正确地使用它们可以提高模型追踪和转换的准确性和效率。

2.2 示例

# 原始函数def my_func(x, y):    if x.sum() > y.sum():        loss = torch.cos(torch.cos(x))    else:        loss = torch.cos(torch.cos(y))    return loss    ===============my compiler=================# 第1个subgraph:判断语句及之前的代码opcode         name    target                  args            kwargs-------------  ------  ----------------------  --------------  --------placeholder    l_x_    L_x_                    ()              {}placeholder    l_y_    L_y_                    ()              {}call_method    sum_1   sum                     (l_x_,)         {}call_method    sum_2   sum                     (l_y_,)         {}call_function  gt      <built-in function gt>  (sum_1, sum_2)  {}output         output  output                  ((gt,),)        {}
# 对应的python代码code is: def forward(self, L_x_ : torch.Tensor, L_y_ : torch.Tensor): l_x_ = L_x_ l_y_ = L_y_ sum_1 = l_x_.sum(); l_x_ = None sum_2 = l_y_.sum(); l_y_ = None gt = sum_1 > sum_2; sum_1 = sum_2 = None return (gt,)
===============my compiler=================# 第2个subgraph:if为True对应一个subgraphopcode name target args kwargs------------- ------ ------------------------------------------------------ ---------- --------placeholder l_x_ L_x_ () {}call_function cos <built-in method cos of type object at 0x7f0f1017b500> (l_x_,) {}call_function loss <built-in method cos of type object at 0x7f0f1017b500> (cos,) {}output output output ((loss,),) {}# 对应的python代码code is: def forward(self, L_x_ : torch.Tensor): l_x_ = L_x_ cos = torch.cos(l_x_); l_x_ = None loss = torch.cos(cos); cos = None return (loss,)
===============my compiler=================# 第3个subgraph:if为False对应一个subgraphopcode name target args kwargs------------- ------ ------------------------------------------------------ ---------- --------placeholder l_y_ L_y_ () {}call_function cos <built-in method cos of type object at 0x7f1254470500> (l_y_,) {}call_function loss <built-in method cos of type object at 0x7f1254470500> (cos,) {}output output output ((loss,),) {}# 对应的python代码code is:def forward(self, L_y_ : torch.Tensor): l_y_ = L_y_ cos = torch.cos(l_y_); l_y_ = None loss = torch.cos(cos); cos = None return (loss,)
复制代码


TorchDynamo 工作示意图



优点


  • 动态优化:能够处理包含动态控制流的模型,如循环和条件语句,适用于动态计算图。

  • 自动化:用户无需手动编写转换代码,可以自动识别和优化执行路径。

  • 灵活性:支持多种后端优化器,如 TorchInductor,提供多样化的性能提升方案。


缺点


  • 工具:作为较新的优化工具,可能在某些极端场景下还不够稳定或全面支持所有 PyTorch 功能,尚未提供序列化/反序列化 API。

  • 依赖后端:最终性能提升依赖于所使用的后端优化器,某些后端可能在特定硬件或模型上表现不佳。 和其他静态图构建方式相比,TorchDynamo 更为灵活且支持更多复杂的操作,而不需要用户做大量的代码修改适配


torch 的几种成图方式对比


3 CPython 代码的执行过程 & PEP 523:

在正式进入 TorchDynamo 工作过程之前先了解 CPython 的工作流程。首先介绍 Cpython 中两个重要的对象——PyCodeObjectPyFrameObject


  • PyCodeObject保存二进制字节码、常量表、变量名表等静态信息;

  • PyFrameObject 是一个用于表示执行环境的对象,每次函数调用时,Python 都会创建一个新的PyFrameObject,其中包含了该函数的 PyCodeObject,以及一些其他运行需要的信息,如存放局部变量的内存空间和evaluation stake(函数调用栈)等。


CPython 在执行 Python 函数前会将 Python 代码编译为字节码,由 Python 虚拟机(PVM)中_PyEval_EvalFrameDefault()函数逐条执行编译好的字节码,而 PEP 523 提供了一个 API 接口让用户在 PVM 执行字节码之前获得待执行的字节码,从而可以对字节码进行优化修改实现即时编译(JIT Compiler)的效果。 TorchDynamo 正是基于 PEP 523 把 TorchDynamo 的编译逻辑引入到 Python 代码的解释执行过程中。


static PyObject* _custom_eval_frame_shim(    PyThreadState* tstate,    THP_EVAL_API_FRAME_OBJECT* frame,    int throw_flag) {  // Shims logic into one of three states. Can probably be refactored into a  // single func, later:  //  - None: disables TorchDynamo  //  - False: run-only mode (reuse existing compiles)  //  - Python callable(): enables TorchDynamo  PyObject* callback = eval_frame_callback_get();
if (callback == Py_None) { return eval_frame_default(tstate, frame, throw_flag); } return _custom_eval_frame(tstate, frame, throw_flag, callback); # 调用编译函数}
// ------------------------------------------------------------------------------------------------------------------------------------static PyObject* _custom_eval_frame( PyThreadState* tstate, THP_EVAL_API_FRAME_OBJECT* frame, int throw_flag, PyObject* callback) { // 省略中间代码,只展示核心函数调用...
PyObject* maybe_cached_code = lookup(extra, frame, NULL); if (maybe_cached_code == NULL) { // Python error return NULL; } else if (maybe_cached_code != Py_None) { PyCodeObject* cached_code = (PyCodeObject*)maybe_cached_code; // used cached version DEBUG_TRACE("cache hit %s", name(frame)); // Re-enable custom behavior eval_frame_callback_set(callback); return eval_custom_code(tstate, frame, cached_code, throw_flag); //命中cache,直接eval_custom_code函数执行frame中的代码 } // cache miss PyObject* result = call_callback(callback, frame, cache_size(extra)); // 未命中则调用回调函数进行编译 if (result == NULL) { return NULL; } else if (result != Py_None) { DEBUG_TRACE("create cache %s", name(frame)); extra = create_cache_entry(extra, result); Py_DECREF(result); set_extra(frame->f_code, extra); // 将编译完成代码添加到frame中 // Re-enable custom behavior eval_frame_callback_set(callback); return eval_custom_code(tstate, frame, extra->code, throw_flag); } }
复制代码


通过 CPython 提供的_PyInterpreterState_SetEvalFrameFunc()函数把 CPython 中用于执行字节码的默认函数给替换为custom_eval_frame_shim()。 在执行用户想要编译的函数时便会进入_custom_eval_frame_shim().


_custom_eval_frame函数中,会先通过lookup函数检查 cache 中是否有已编译代码,若存在则直接调用eval_custom_code函数执行,从而避免重复编译相同函数。若 cache 未命中,则通过call_callback调用回调函数进行编译,并通过set_extra()将编译结果保存在 PyFrameObject 中,最后调用eval_custom_code继续进行执行。而这里的回调函数也即前面在torch._dynamo.optimize传入的回调函数convert_frame.convert_frame(backend, hooks=hooks)(包含编译入口 compile_fn)。


因此 torch.compile 只有在第一次正式执行代码前才会进行编译,这也导致测试编译代码的时间时需要考虑数据预热。


到此,解释了torch.compile是如何在 Python 代码执行过程中引入TorchDynamo的,接下来回到torch._dynamo.optimize解析是如何一步步从字节码构建 FX Graph。

4 TorchDynamo 模拟执行 & FX Graph 构建

回到torch._dynamo.optimize设置的回调函数convert_frame.convert_frame(backend, hooks=hooks),其核心函数是_compile(),用于负责对字节码进行编译。先看代码。


#ps:只展示核心函数调用# path:torch/_dynamo/convert_frame.py@compile_time_strobelight_meta(phase_name="_compile")@_use_lazy_graph_module(config.use_lazy_graph_module)def _compile(    code: types.CodeType,    globals: Dict[str, object],    locals: Dict[str, object],    builtins: Dict[str, object],    compiler_fn: CompilerFn,    one_graph: bool,    export: bool,    export_constraints,    hooks: Hooks,    cache_entry,    cache_size: CacheSizeRelevantForFrame,    frame: Optional[types.FrameType] = None,    frame_state=None,    compile_id=None,    *,    skip: int = 0,) -> Optional[GuardedCode]:    exceeded, limit_type = exceeds_cache_size_limit(cache_size)    # 判断缓存大小是否超过阈值,默认为64    try:        guarded_code = compile_inner(code, one_graph, hooks, transform)        return guarded_code    except ...    finally:            tracer.output.call_cleanup_hooks()
output = tracer.output assert output is not None assert output.output_instructions instructions[:] = output.output_instructions code_options.update(output.code_options)
if config.dead_code_elimination: propagate_inst_exn_table_entries(instructions) check_inst_exn_tab_entries_valid(instructions) instructions[:] = remove_pointless_jumps(remove_dead_code(instructions)) # 根据FX Graph分析未调用的代码,并进行剔除 #--------------------------------------------------------------------------------------------------------------------------------- def compile_inner( code: types.CodeType, one_graph: bool, hooks: Hooks, transform: Callable[[List[Instruction], Dict[str, Any]], Any], ) -> Optional[GuardedCode]: for attempt in itertools.count(): CompileContext.get().attempt = attempt try: out_code = transform_code_object(code, transform) # 编译用户想要优化的函数 break except ...
#-------------------------------------------------------------------------------------------------------------------------------def transform_code_object(code, transformations, safe=False) -> types.CodeType: keys = get_code_keys() code_options = {k: getattr(code, k) for k in keys} assert len(code_options["co_varnames"]) == code_options["co_nlocals"]
instructions = cleaned_instructions(code, safe) propagate_line_nums(instructions)
transformations(instructions, code_options) return clean_and_assemble_instructions(instructions, keys, code_options)[1]
def cleaned_instructions(code, safe=False) -> List[Instruction]: instructions = list(map(convert_instruction, dis.get_instructions(code))) check_offsets(instructions) if sys.version_info >= (3, 11): populate_kw_names_argval(instructions, code.co_consts) virtualize_exception_table(code.co_exceptiontable, instructions) virtualize_jumps(instructions) strip_extended_args(instructions) if not safe: if sys.version_info < (3, 11): remove_load_call_method(instructions) if sys.version_info < (3, 12): explicit_super(code, instructions) if sys.version_info >= (3, 11): remove_jump_if_none(instructions) if sys.version_info >= (3, 12): remove_binary_store_slice(instructions) update_offsets(instructions) devirtualize_jumps(instructions) return instructions
复制代码


根据如上的代码逻辑,抽象出如下的流程图。



_compile()中会对缓存大小进行判断,如果缓存大小超过配置会有警告信息,默认值为 64,含义是对于同一个 Python 函数,如果函数的输入张量信息组合变化超过 64 种,TorchDynamo 则不会继续编译用户指定的函数。 在_compile 中通过transform_code_object(code, transform)对用户代码进行优化转换,其中codePyCodeObject类型,即待编译优化的字节码,transformer是转换函数。在transform_code_object函数中,cleaned_instructions()用来预处理字节码指令,通过 Python 标准库 dis.get_instructions(code) 获取字节码指令,对字节码进行清洗(例如对跳转指令做标准化处理),并转为结构化数据表示 Instruction,方便后续的优化。


例如对我们的样例函数 my_func 进行 cleaned_instructions()后的结果如下。


Instruction(opcode=113, opname='JUMP_ABSOLUTE', arg=18, argval=18, offset=0, starts_line=13, is_jump_target=False, positions=None, target=Instruction(opcode=116, opname='LOAD_GLOBAL', arg=1, argval='torch', offset=18, starts_line=14, is_jump_target=True, positions=None, target=None, exn_tab_entry=None), exn_tab_entry=None)Instruction(opcode=124, opname='LOAD_FAST', arg=0, argval='x', offset=2, starts_line=None, is_jump_target=False, positions=None, target=None, exn_tab_entry=None)Instruction(opcode=106, opname='LOAD_ATTR', arg=0, argval='sum', offset=4, starts_line=None, is_jump_target=False, positions=None, target=None, exn_tab_entry=None)Instruction(opcode=131, opname='CALL_FUNCTION', arg=0, argval=0, offset=6, starts_line=None, is_jump_target=False, positions=None, target=None, exn_tab_entry=None)Instruction(opcode=124, opname='LOAD_FAST', arg=1, argval='y', offset=8, starts_line=None, is_jump_target=False, positions=None, target=None, exn_tab_entry=None)Instruction(opcode=106, opname='LOAD_ATTR', arg=0, argval='sum', offset=10, starts_line=None, is_jump_target=False, positions=None, target=None, exn_tab_entry=None)Instruction(opcode=131, opname='CALL_FUNCTION', arg=0, argval=0, offset=12, starts_line=None, is_jump_target=False, positions=None, target=None, exn_tab_entry=None)Instruction(opcode=107, opname='COMPARE_OP', arg=4, argval='>', offset=14, starts_line=None, is_jump_target=False, positions=None, target=None, exn_tab_entry=None)Instruction(opcode=114, opname='POP_JUMP_IF_FALSE', arg=36, argval=36, offset=16, starts_line=None, is_jump_target=False, positions=None, target=Instruction(opcode=116, opname='LOAD_GLOBAL', arg=1, argval='torch', offset=36, starts_line=16, is_jump_target=True, positions=None, target=None, exn_tab_entry=None), exn_tab_entry=None)Instruction(opcode=116, opname='LOAD_GLOBAL', arg=1, argval='torch', offset=18, starts_line=14, is_jump_target=True, positions=None, target=None, exn_tab_entry=None)Instruction(opcode=106, opname='LOAD_ATTR', arg=2, argval='cos', offset=20, starts_line=None, is_jump_target=False, positions=None, target=None, exn_tab_entry=None)Instruction(opcode=116, opname='LOAD_GLOBAL', arg=1, argval='torch', offset=22, starts_line=None, is_jump_target=False, positions=None, target=None, exn_tab_entry=None)Instruction(opcode=106, opname='LOAD_ATTR', arg=2, argval='cos', offset=24, starts_line=None, is_jump_target=False, positions=None, target=None, exn_tab_entry=None)Instruction(opcode=124, opname='LOAD_FAST', arg=0, argval='x', offset=26, starts_line=None, is_jump_target=False, positions=None, target=None, exn_tab_entry=None)Instruction(opcode=131, opname='CALL_FUNCTION', arg=1, argval=1, offset=28, starts_line=None, is_jump_target=False, positions=None, target=None, exn_tab_entry=None)Instruction(opcode=131, opname='CALL_FUNCTION', arg=1, argval=1, offset=30, starts_line=None, is_jump_target=False, positions=None, target=None, exn_tab_entry=None)Instruction(opcode=125, opname='STORE_FAST', arg=2, argval='loss', offset=32, starts_line=None, is_jump_target=False, positions=None, target=None, exn_tab_entry=None)Instruction(opcode=110, opname='JUMP_FORWARD', arg=16, argval=52, offset=34, starts_line=None, is_jump_target=False, positions=None, target=Instruction(opcode=124, opname='LOAD_FAST', arg=2, argval='loss', offset=52, starts_line=17, is_jump_target=True, positions=None, target=None, exn_tab_entry=None), exn_tab_entry=None)Instruction(opcode=116, opname='LOAD_GLOBAL', arg=1, argval='torch', offset=36, starts_line=16, is_jump_target=True, positions=None, target=None, exn_tab_entry=None)Instruction(opcode=106, opname='LOAD_ATTR', arg=2, argval='cos', offset=38, starts_line=None, is_jump_target=False, positions=None, target=None, exn_tab_entry=None)Instruction(opcode=116, opname='LOAD_GLOBAL', arg=1, argval='torch', offset=40, starts_line=None, is_jump_target=False, positions=None, target=None, exn_tab_entry=None)Instruction(opcode=106, opname='LOAD_ATTR', arg=2, argval='cos', offset=42, starts_line=None, is_jump_target=False, positions=None, target=None, exn_tab_entry=None)Instruction(opcode=124, opname='LOAD_FAST', arg=1, argval='y', offset=44, starts_line=None, is_jump_target=False, positions=None, target=None, exn_tab_entry=None)Instruction(opcode=131, opname='CALL_FUNCTION', arg=1, argval=1, offset=46, starts_line=None, is_jump_target=False, positions=None, target=None, exn_tab_entry=None)Instruction(opcode=131, opname='CALL_FUNCTION', arg=1, argval=1, offset=48, starts_line=None, is_jump_target=False, positions=None, target=None, exn_tab_entry=None)Instruction(opcode=125, opname='STORE_FAST', arg=2, argval='loss', offset=50, starts_line=None, is_jump_target=False, positions=None, target=None, exn_tab_entry=None)Instruction(opcode=124, opname='LOAD_FAST', arg=2, argval='loss', offset=52, starts_line=17, is_jump_target=True, positions=None, target=None, exn_tab_entry=None)Instruction(opcode=83, opname='RETURN_VALUE', arg=None, argval=None, offset=54, starts_line=None, is_jump_target=False, positions=None, target=None, exn_tab_entry=None)
复制代码


清理后的字节码通过transformations(instructions, code_options)函数进行处理,核心 transform()函数实现如下。


def transform(instructions, code_options):    nonlocal output    nonlocal tracer    speculation_log.restart()    tracer = InstructionTranslator(        instructions,        code,        locals,        globals,        builtins,        code_options,        compiler_fn,        one_graph,        export,        export_constraints,        mutated_closure_cell_contents,        frame_state=frame_state,        speculation_log=speculation_log,    )    try:        with tracing(tracer.output.tracing_context), tracer.set_current_tx():            tracer.run()    except ...
复制代码


首先实例化了InstructionTranslator对象,在InstructionTranslator中有一个 OutputGraph 的实例,用于保存InstructionTranslator编译后的输出,以 torch.fx.Graph表示。


其中transform()(也即 TorchDynamo)构建 FX Graph 的核心模块有两个:


  • InstructionTranslator的初始化过程,负责对变量构建对应的 Proxy,对应 FX Graph 中的placeholder部分; -

  • InstructionTranslator.run()负责模拟运行字节码并构建对应的 Node 添加到 FX Graph 中;


InstructionTranslator的初始化过程中,通过 PyCodeObject 对象的 co_varnames 字段获取待编译函数中的变量名,并为每一个变量创建一个LazyVariableTracker,作为symbolic_locals。其中LazyVariableTracker是一种推迟创建给定底层值的VariableTracker,直到访问该值才创建用于节省空间资源,而VariableTracker被用于记录每个 Python 变量对应的类型信息用于构建静态图。


vars = list(code_options["co_varnames"])cells_and_freevars = [x for x in self.cell_and_freevars() if x not in vars]vars.extend(cells_and_freevars)cells_and_freevars_set = set(cells_and_freevars)
self.symbolic_locals = { k: variables.LazyVariableTracker.create( f_locals[k], source=LocalSource(k, cell_or_freevar=k in cells_and_freevars_set), ) for k in vars if k in f_locals}
复制代码


其中LazyVariableTracker通过 VariableBuilder 来生成实际对象。以torch.Tensor为例,VariableBuilder 的创建逻辑如下:


tensor_proxy = self.tx.output.root_tracer.create_graph_input(    re.sub(r"[^a-zA-Z0-9]+", "_", self.name), type(value), source=source)options = {}if type(value) in config.traceable_tensor_subclasses:    options["torch_function_fn"] = build_torch_function_fn(        self.tx, value, self.source    )    self.install_guards(GuardBuilder.TYPE_MATCH)
tensor_variable = wrap_fx_proxy( tx=self.tx, proxy=tensor_proxy, example_value=value, subclass_type=subclass_type, source=source, **options,)
guard_type = GuardBuilder.TENSOR_MATCH
if isinstance(source, GradSource) and is_from_optimizer_source(source): guard_type = GuardBuilder.NOT_NONE_MATCH
self.install_guards( functools.partial( guard_type, value=value if isinstance(source, NumpyTensorSource) else TensorWeakRef(value), ))

#----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------# 节省篇幅,这里只展示核心函数调用def wrap_fx_proxy_cls( target_cls, tx, proxy, example_value=None, subclass_type=None, **options): example_value = wrap_to_fake_tensor_and_record(example_value, tx=tx, **kwargs) if isinstance(example_value, torch.Tensor): set_example_value(proxy.node, example_value) specialized_props = target_cls.specialize(example_value) options.update(specialized_props) return target_cls(proxy, **options)
复制代码


  1. 首先通过create_graph_input()在 FX Graph 中创建了类型为 placeholder 的 FX Proxy(FX Proxy 是 FX symbolic tracing 中的 symbol,placeholder 对应变量,即前面的 opcode)。

  2. install_guards()函数创建了类型为 GuardBuilder.TYPE_MATCH 的 Guard 对象,Guard 在 TorchDynamo 中负责检测被编译函数所引用的外部数据信息是否发生变化,如果没有发生变化则可以复用之前编译好的函数,否则需要重新编译该函数。TYPE_MATCH 主要判断两者的数据类型是否一致,TENSOR_MATCH 主要对输入 Tensor 的 shape、stride 等信息进行检查是否发生改变

  3. wrap_fx_proxy()为刚刚创建的 Proxy 建立实际的 VariableTracker,核心逻辑实现在wrap_fx_proxy_cls()函数:

  4. wrap_fx_proxy_cls()函数中首先通过wrap_to_fake_tensor_and_record()函数为运行时获得的 torch.Tensor 创建 FakeTensor(默认情况下,TorchDynamo 使用 FakeTensor 捕获计算图而不是真实的 torch.Tensor,FakeTensor 具有和真实 torch.Tensor 相同的张量信息,但没有实际的数据和张量内存分配);

  5. 通过specialize()函数特化张量信息(包括 dtype、device 等),在 static shape 模式下还会特化 size、stride、is_contiguous 信息,而在 dynamic shape 模式下则不会特化这部分信息。

  6. 最后通过target_cls创建对应的VariableTracker对象,例如这里的是 torch.Tensor,则创建的是 TensorVariable(VariableTracker 的子类),用于记录 Pytorch 中的 torch.Tensor 类型数据的相关信息。


因此,在InstructionTranslator对象初始化创建VariableTracker的过程中,TorchDynamo 完成了:


  • 在 FX Graph 中创建 FXProxy

  • 添加 Guard 和 FakeTensor 相关操作并初始化 VariableTracker,由于并不是所有的局部变量都会被当前 frame 用到,为了节省资源开销这里采用 LazyVariableTracker,只有到实际使用的时候才会进行创建。


到此完成输入对应的 VariableTracker 创建,会在后续一直带着 Guard、FakeTensor 等信息用于跟踪 Tensor 的后续操作。


完成InstructionTranslator对象的初始化,回到InstructionTranslator.run()函数,由于InstructionTranslator继承于class InstructionTranslatorBase,所以这里实际调用的是InstructionTranslatorBase.run()函数。InstructionTranslatorBase 的本质是一个 Python 虚拟机的模拟器,在循环中对字节码逐条解析模拟执行加粗样式,对其核心函数step()进行分析,首先基于 instruction_pointer 获取待执行的字节码指令,通过 dispatch_table 映射表获取到每个 op 对应的函数调用并进入函数解析当前字节码指令,当遇到循环、if/else 等跳转相关的字节码指令时会触发 compile_subgraph()函数进入子图编译相关操作。代码如下:


def run(self):    with self.run_ctx_mgr():        try:            self.output.push_tx(self)            while self.step():        # 循环调用step()函数模拟执行python字节码                pass        except BackendCompilerFailed:            raise
# ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------def step(self): """Process exactly one instruction, return False we should exit""" ip = self.instruction_pointer # Python虚拟机的PC(Program Counter),表明当前正在执行的字节码指令所处位置; if ip is None: return False self.current_instruction = inst = self.instructions[ip] self.instruction_pointer = ip + 1
if inst.starts_line: self.starts_line(inst.starts_line)
# stack:Python虚拟机的数据栈,Python虚拟机中字节码之间通过数据栈交换数据 if ( not self.stack and self.should_compile_partial_graph() and self.is_non_empty_graph() ): self.current_speculation = self.speculate() # 循环、if/else等语句都会触发,进入step_graph_break进行子图编译 if self.current_speculation.failed: return self.step_graph_break(inst)
self.update_block_stack(inst)
try: self.dispatch_table[inst.opcode](self, inst) # 逐条执行指令 return not self.output.should_exit # 执行结束退出 except exc.ObservedException: self.exception_handler() return True except ReturnValueOp: return False # 返回指令 except Unsupported: if self.current_speculation is None: log.debug("empty checkpoint") raise log.debug("step triggered compile", exc_info=True)
复制代码


对于每条字节码的模拟执行和解析,以下面的CALL_FUNCTION函数调用为例,会先根据argval弹出对应的函数参数,并进一步调用TensorVariable.call_method()函数。在call_method()函数中,proxy_args_kwargs()函数从symbolic_locals中获取相应的函数参数 Proxy,然后调用create_proxy()创建新的 Proxy,类型是call_method,并有对应的 method 名(如 my_func 中的 x.sum()对应 TorchDynamo 中的 target 项)和参数。最后通过wrap_fx_proxy()(和前面创建局部变量一样)创建新的 TensorVariable 来保存结果,中间收集到的 Guard 信息也附加了上去,最后在 call_function()函数中将结果压栈。到此完成当前字节码的模拟执行并在此过程中将对应的 Proxy 添加到 FX Graph 中。


@break_graph_if_unsupported(push=1)def CALL_FUNCTION(self, inst):    args = self.popn(inst.argval)    fn = self.pop()    self.call_function(fn, args, {})
# ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------# TensorVariabledef call_method( self, tx, name, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]",) -> "VariableTracker": # 省略中间代码... return wrap_fx_proxy( tx, tx.output.create_proxy( "call_method", name, *proxy_args_kwargs([self, *args], kwargs), ), )
复制代码


因此,TorchDynamo 在字节码的分析过程中并没有真正地执行指令,而是以符号分析的方式从字节码中提取相应的符号和函数,创建相应的 Proxy 并添加到 FX Graph 中,通过指令逐条模拟执行(解析)不断构建 FX Graph,直到触发 compile_subgraph()。

5 子图编译

在前面的分析中,TorchDynamo 都是逐条执行指令然后不断地构建 FX Graph,但当遇到例如 jump 字节码指令(对应 if/else、循环等)时,会触发compile_subgraph()函数,因为在 TorchDynamo 中是以子图为单位进行编译的(除了设置full_graph=True),在compile_subgraph()中完成 FX Graph 一个完整子图的构建,并调用 backend compiler 对该子图进行编译。 分析compile_subgraph()的核心函数compile_and_call_fx_graph()实现:


  • 调用create_node创建类型为 output 的 Proxy(对应输出返回值),到此一张完整的 FX Graph 构建完毕。

  • 基于完整的 FX Graph 创建对应的 GraphModule 作为编译函数的入参,并通过call_user_compiler()函数调用 backend compiler 对 GraphModule 进行编译(在这里开始进入 inductor 编译函数)。

  • 通过PyCodegen.get_instructions()函数获得编译后函数对应的 Instructions,到此完成整个子图的编译部分。


  # ps:只展示核心函数调用,省略中间过程 def compile_and_call_fx_graph(self, tx, rv, root):   """ Generate code from self.graph and return the Instruction()s to call that generated code. """   self.create_node( "output", "output", (self.current_tracer.create_arg(tuple(x.as_proxy() for x in rv)),), {}, )   gm = _make_graph_module(root, self.graph) # 创建GraphModule,对应编译函数的传入参数,用于编译
with self.restore_global_state(): compiled_fn = self.call_user_compiler(gm) # 这里调用用户指定的backend compiler进行编译,如inductor等
compiled_fn = disable(compiled_fn) # 禁止TorchDynamo再次编译已编译的函数
cg = PyCodegen(tx)cg.make_call_generated_code(name) # 生成字节码return cg.get_instructions() # 获取编译后代码的Instructions
复制代码


最后调用add_output_instructions()函数将 should_exit 属性置为 True,这意味回到InstructionTranslatorBase.run()循环中会退出。 一路回退到_compile(),在transform_code_object()中会调用clean_and_assemble_instructions()将 Instruction 汇编为 Python 可执行的字节码。

6 Guard 生成

对于静态图的生成,特别是 static shape,TorchDynamo 中还有一个重要的组成部分——Guard。随着上述提及的函数调用栈一路回退到_compile(),此时已经完成了 FX Graph 的构建、调用 backend compiler 进行了编译并对编译后代码生成字节码,TorchDynamo 最后需要为之前构建 FX Graph 过程中收集的 Guard 生成检测代码(Python 代码),从而在后续执行代码时检测代码是否已编译过。 TorchDynamo 通过CheckFunctionManagercompile_check_fn()函数为 Guard 生成可执行 Python 代码,为了降低运行时检测输入是否发生变化的函数开销,TorchDynamo 把 Guard 检测功能实现在了 C++中。(具体实现可以查阅 /usr/local/lib/python3.9/dist-packages/torch/_dynamo/guards.py ) 回到_compile()函数,check_fn即生成的 Guard 检查函数,GuardedCode保存编译好的子图out_codecheck_fn


check_fn = CheckFunctionManager(    output,    hooks.guard_fail_fn if hooks else None,)
guarded_code = GuardedCode(out_code, check_fn.check_fn)
复制代码


到此完成了一个完整的子图的全部构建和编译工作,最终回到最开始的_custom_eval_frame()函数,对编译完的代码调用eval_custom_code(),送入 CPython 默认的执行函数入口_PyEval_EvalFrameDefault进行执行,完成编译后子图的执行(和 Pytorch eager 模式执行一样)。

7 后续函数的执行过程

回到_custom_eval_frame()函数,此时拿到了编译好的 GuardedCode,create_cache_entry()和 set_extra()往当前用户函数的 frame->f_code 里写入了一跳 CacheEntry,记录了 check_fn 和编译好的 code。eval_custom_code()创建了一个新的 Python Frame,并运行编译好的函数。 eval_custom_code() 中直接调用了eval_frame_default()来执行上面的字节码,所以此处不会再次触发 TorchDynamo 定制的 Frame Evaluation 函数custom_eval_frame_shim()。执行完编译过的子图,程序返回到 Python 解释器,下一条字节码是 if/else 对应的跳转指令,会再次触发 TorchDynamo 设置的 Frame Evaluation 函数 custom_eval_frame_shim()继续进行子图的捕获和编译。

用户头像

zjun

关注

还未添加个人签名 2020-03-06 加入

还未添加个人简介

评论

发布
暂无评论
Ascend的aclgraph(三)TorchDynamo_PyTorch_zjun_InfoQ写作社区