写点什么

Ascend 的 aclgraph(五)PrimTorch & TorchInductor

作者:zjun
  • 2025-05-19
    上海
  • 本文字数:5086 字

    阅读完需:约 17 分钟

Ascend的aclgraph(五)PrimTorch & TorchInductor

1 PrimTorch

参考:Torch.compile()流程解析——4. PrimTorch & TorchInductor


在上一篇Ascend的aclgraph(四)AOT Autograd构建 joint graph 的时候提及过 op 执行的时候,通过ProxyTorchDispatchModetorch_dispatch对 op 进行decompose,具体流程是:


  1. 调用maybe_handle_decomp()函数在CURRENT_DECOMPOSITION_TABLE(一个 Aten op 映射表)中查找 op 对应的函数实现并返回,若未实现则进入 b;

  2. 若不是则调用decompose()函数继续进行拆解,decompose()实现逻辑如下:


# decompose()函数实现# path:/torch/_ops.pydef decompose(self, *args, **kwargs):    dk = torch._C.DispatchKey.CompositeImplicitAutograd    if dk in self.py_kernels:        # NB: This branch is not too necessary anymore, because we can        # apply Python CompositeImplicitAutograd *before* tracing        # using Python dispatcher (also taking advantage of the autograd        # formula).  But it's included for completeness        return self.py_kernels[dk](*args, **kwargs)    elif torch._C._dispatch_has_kernel_for_dispatch_key(self.name(), dk):        return self._op_dk(dk, *args, **kwargs)    else:        return NotImplemented
复制代码


从而实现 high level op 一步步拆解到 Aten op 的过程。总的来说,PrimTorch 是一种规定,将所有的 op 拆解为一个约定的 op 规范集合,并作为开发者和硬件厂商之间的一种中间桥梁,Pytorch 前端将 op 拆解映射到 PrimTorch,而硬件厂商针对这些特定的 op 进行优化即可


2 TorchInductor

TorchInductor 是 PyTorch 的一个高性能编译后端,专注于将优化后的计算图转换为高效的、针对特定硬件(如 CPU、GPU)的内核代码。它利用多种优化技术,包括内存优化、并行化和低层次的代码生成,以最大化计算性能。aot_dispatch_autograd()函数在拿到前反向的 FX Graph 后,分别调用fw_compilerbw_compiler对前反向图进行编译,这里的fw_compilerbw_compiler可以是不同的 compiler(npu 就是自定义的),在 inductor 的默认实现中调用的是compile_fx_inner,而其中的核心函数是fx_codegen_and_compile(),负责对 FX Graph 进行图优化、Triton 内核代码生成等

2.1 TorchInductor 函数调用

compile_fx_inner的核心实现逻辑如下。



fx_codegen_and_compile()中比较重要的三个函数是:


  1. _recursive_post_grad_passes:负责对计算图进一步的优化,包括:

  2. group_batch_fusion_passes:对 batch_linear、batch_relu、batch_sigmoid 等归一化操作进行算子融合,根据融合规则,然后以 BFS 的方式查找符合规则的 op 进行融合。

  3. remove_noop_ops:移除图形中本质上是 aten.clone 和 aten.alias 的操作。

  4. fuse_ddp_communication:对 ddp 通信的部分操作调用合并函数进行融合。

  5. decompose_auto_functionalized:对 high-level op 进一步进行拆解(因为前面进行算子融合那些操作可能会引入新的 high level op 所以这里再操作一遍),将高层次的操作逐步转换为更低层次的实现。

  6. GraphLowering:把 FX Graph 进一步降为 Inductor IR,即前面的计算图被进一步转换为低层次的中间表示。这一表示更加接近最终的机器代码,并且适合进一步的代码生成和优化。

  7. GraphLowering.compile_to_fn():负责对前面生成的 IR 表示转换为针对目标硬件低层次代码,GPU 上会生成 Triton,CPU 上会生成 OpenMP/C++,npu 上是?后续解答,同时可能会利用 SIMD 指令和多线程并行化来加速计算,是 inductor 中一个核心的实现。

2.2 compile_to_fn()——内核代码生成

compile_to_fn()在 Scheduler 类中实现内核代码编译的核心功能。而 Scheduler 的两个函数值得关注:


  1. Scheduler.init():实现算子融合等优化,基本流程为:

  2. compute_dependencies():分析 op 之间的依赖关系;

  3. fuse_nodes():合并节点,核心逻辑是通过get_possible_fusions获取可融合算子组合(这里只是先选出可融合的,因为可能 op 之间有交集,所以并未直接执行融合,而是筛出可融合的组合并排序再进行按序融合),然后再调用can_fuse()进一步检查是否可融合,最后进行融合,其中两个重要的函数是can_fuse()检查两个 op 融合是否合法,score_fusion()对给定的融合 op 排一个优先级(当融合 op 组合冲突时以排序分数高的先融合,排序得分基于<1>节省的内存操作的估计,<2> 尽量保持原始操作顺序);

  4. Scheduler.codegen():

  5. codegen_extern_call():是对部分 kernel 决策进行就地更改并记录决策(没看明白什么操作)

  6. self.get_backend(device).codegen_node(node):根据 device 调用 codegen_node 生成针对目标硬件的内核代码,如在/usr/local/lib/python3.9/dist-packages/torch/_inductor/codegen/cuda_combined_scheduling.py::codegen_node中实现了生成 Triton 内核代码。


回到compile_to_module(),将前面生成内核代码以.py 文件方式(triton 实现)保存到PyCodeCache中,最后调用PyCodeCache.load_by_key_path()获得编译后的 module(这个 module 包含 triton 代码的临时文件路径),返回到fx_codegen_and_compile()函数中将进一步封装成CompiledFxGraph


最后回到compile_fx_inner函数中,若支持 cudagraph 还会对编译后的图进行 cudagraph 编译优化(torch.compile 的 recude-overhead 模式下会自动添加 CUDA Graph 来减小运行时开销)。读到这里,以为 aclgraph 中的 recude-overhead 是专为 npu 添加,原来是来自于这里。


cudagraph 优化具体流程为:


  1. has_incompatible_cudagraph_ops():检查是否存在与 cudagraph 不兼容的 op

  2. cudagraphify():将子图转为 cudagraph 进行优化


到此完成 TorchInductor 的编译部分,返回一个 Triton 内核代码实现的 CompiledFxGraph,最后一路返回到compile_fx()即 inductor 的入口,又回到了调用call_user_compiler处,继续进行后续的操作(Ascend的aclgraph(三)TorchDynamo)完成整个流程。


# Inductor核心函数实现def fx_codegen_and_compile(    gm: torch.fx.GraphModule,    example_inputs: List[torch.Tensor],    cudagraphs: Optional[BoxedBool] = None,    static_input_idxs: Optional[List[int]] = None,    is_backward: bool = False,    graph_id: Optional[int] = None,    cpp_wrapper: bool = False,    aot_mode: bool = False,    is_inference: bool = False,    # Use a dict with None value rather than a set for deterministic    # iteration order just in case.    user_visible_outputs: Optional[Dict[str, None]] = None,    layout_opt: Optional[bool] = None,    extern_node_serializer: Optional[Callable[[List[ExternKernelNode]], Any]] = None,) -> Union[CompiledFxGraph, str]:    # 省略中间...    V.debug.fx_graph(gm, example_inputs)    shape_env = _shape_env_from_inputs(example_inputs)    view_to_reshape(gm)
with torch.no_grad(): fake_mode = fake_tensor_prop(gm, example_inputs)
with V.set_fake_mode(fake_mode): # has some issues with memory in training _recursive_post_grad_passes(gm, is_inference=is_inference) # 优化计算图,包括group_batch_fusion、remove_noop_ops(拷贝别名处理)、fuse_ddp_communication等 V.debug.fx_graph_transformed(gm, example_inputs) post_grad_graphs_log.debug( "%s", lazy_format_graph_code( "AFTER POST GRAD", gm, include_stride=True, include_device=True ), ) trace_structured( "inductor_post_grad_graph", payload_fn=lambda: gm.print_readable( print_output=False, include_stride=True, include_device=True ), ) if config.is_fbcode(): log_optimus_to_scuba( extra_logging={"pt2_configs": str(get_patched_config_dict())} )
with V.set_fake_mode(fake_mode), maybe_disable_comprehensive_padding( example_inputs ): const_output_index = None const_graph = None const_code = None
if aot_mode and config.aot_inductor.use_runtime_constant_folding: const_gm, const_output_index = split_const_gm(gm)
const_graph = GraphLowering( const_gm, example_inputs=[], shape_env=shape_env, graph_id=graph_id, cpp_wrapper=cpp_wrapper, aot_mode=aot_mode, user_visible_outputs=user_visible_outputs, extern_node_serializer=extern_node_serializer, is_inference=is_inference, is_const_graph=True, ) with V.set_graph_handler(const_graph): assert cpp_wrapper, "AOT mode only supports C++ wrapper" const_graph.run()
const_code, _ = const_graph.codegen_with_cpp_wrapper() # 降为inductor IR以进一步的优化 graph = GraphLowering( gm, # example_inputs will be used by AOTInductor to dry-run the generated code for Triton kernel tuning. # For the forward pass, we have the real inputs to be used as example_inputs. For the backward pass, # we currently use fake tensors and defake them later. example_inputs=example_inputs, shape_env=shape_env, graph_id=graph_id, cpp_wrapper=cpp_wrapper, aot_mode=aot_mode, user_visible_outputs=user_visible_outputs, extern_node_serializer=extern_node_serializer, is_inference=is_inference, const_output_index=const_output_index, const_code=const_code, const_module=const_graph, ) metrics_helper = metrics.CachedMetricsHelper() with V.set_graph_handler(graph): graph.run(*example_inputs) output_strides: List[Optional[Tuple[int, ...]]] = [] if graph.graph_outputs is not None: # We'll put the output strides in the compiled graph so we # can later return them to the caller via TracingContext for out in graph.graph_outputs: if ( hasattr(out, "layout") and len(free_unbacked_symbols(out.layout.stride)) == 0 ): output_strides.append( tuple( V.graph.sizevars.size_hint(s) for s in out.layout.stride ) ) else: output_strides.append(None)
_check_triton_bf16_support(graph) compiled_fn = graph.compile_to_fn() # 生成对应的后端内核代码,GPU为Triton,CPU为C++/OpenMP
# 省略中间代码...
# 将编译后代码封装成CompiledFxGraph并返回 compiled_graph = CompiledFxGraph( compiled_fn, graph, output_strides, V.graph.disable_cudagraphs_reason, metrics_helper.get_deltas(), )
return compiled_graph
复制代码


到此梳理完了torch.compile()函数的整体流程,解析了从 TorchDynamo 捕获计算图、再到 AOTAutograd 捕获前反向计算图并进行算子decompose、以及最后在 TorchInductor 中完成算子融合和 kernel 代码生成的实现逻辑,后续再对其中的部分实现细节进行深入分析。

3 小结

用 3 篇的内容介绍了 torch.compile 中的相关概念,后续接着回到 torchair 的代码中,继续分析 aclgraph 的成图逻辑,以及 npu 上的后端是怎么执行的。

用户头像

zjun

关注

还未添加个人签名 2020-03-06 加入

还未添加个人简介

评论

发布
暂无评论
Ascend的aclgraph(五)PrimTorch & TorchInductor_PyTorch_zjun_InfoQ写作社区