Ascend 的 aclgraph(五)PrimTorch & TorchInductor

1 PrimTorch
参考:Torch.compile()流程解析——4. PrimTorch & TorchInductor
在上一篇Ascend的aclgraph(四)AOT Autograd构建 joint graph 的时候提及过 op 执行的时候,通过ProxyTorchDispatchMode
的torch_dispatch
对 op 进行decompose
,具体流程是:
调用
maybe_handle_decomp
()函数在CURRENT_DECOMPOSITION_TABLE
(一个 Aten op 映射表)中查找 op 对应的函数实现并返回,若未实现则进入 b;若不是则调用
decompose
()函数继续进行拆解,decompose()实现逻辑如下:
从而实现 high level op 一步步拆解到 Aten op 的过程。总的来说,PrimTorch 是一种规定,将所有的 op 拆解为一个约定的 op 规范集合,并作为开发者和硬件厂商之间的一种中间桥梁,Pytorch 前端将 op 拆解映射到 PrimTorch,而硬件厂商针对这些特定的 op 进行优化即可。

2 TorchInductor
TorchInductor
是 PyTorch 的一个高性能编译后端,专注于将优化后的计算图转换为高效的、针对特定硬件(如 CPU、GPU)的内核代码。它利用多种优化技术,包括内存优化、并行化和低层次的代码生成,以最大化计算性能。aot_dispatch_autograd
()函数在拿到前反向的 FX Graph 后,分别调用fw_compiler
、bw_compiler
对前反向图进行编译,这里的fw_compiler
和bw_compiler
可以是不同的 compiler(npu 就是自定义的),在 inductor 的默认实现中调用的是compile_fx_inner
,而其中的核心函数是fx_codegen_and_compile
(),负责对 FX Graph 进行图优化、Triton 内核代码生成等。
2.1 TorchInductor 函数调用
compile_fx_inner
的核心实现逻辑如下。

fx_codegen_and_compile
()中比较重要的三个函数是:
_recursive_post_grad_passes
:负责对计算图进一步的优化,包括:group_batch_fusion_passes
:对 batch_linear、batch_relu、batch_sigmoid 等归一化操作进行算子融合,根据融合规则,然后以 BFS 的方式查找符合规则的 op 进行融合。remove_noop_ops
:移除图形中本质上是 aten.clone 和 aten.alias 的操作。fuse_ddp_communication
:对 ddp 通信的部分操作调用合并函数进行融合。decompose_auto_functionalized
:对 high-level op 进一步进行拆解(因为前面进行算子融合那些操作可能会引入新的 high level op 所以这里再操作一遍),将高层次的操作逐步转换为更低层次的实现。GraphLowering
:把 FX Graph 进一步降为 Inductor IR,即前面的计算图被进一步转换为低层次的中间表示。这一表示更加接近最终的机器代码,并且适合进一步的代码生成和优化。GraphLowering.compile_to_fn
():负责对前面生成的 IR 表示转换为针对目标硬件低层次代码,GPU 上会生成 Triton,CPU 上会生成 OpenMP/C++,npu 上是?后续解答,同时可能会利用 SIMD 指令和多线程并行化来加速计算,是 inductor 中一个核心的实现。
2.2 compile_to_fn()——内核代码生成
compile_to_fn
()在 Scheduler 类中实现内核代码编译的核心功能。而 Scheduler 的两个函数值得关注:
Scheduler.init
():实现算子融合等优化,基本流程为:compute_dependencies
():分析 op 之间的依赖关系;fuse_nodes
():合并节点,核心逻辑是通过get_possible_fusions
获取可融合算子组合(这里只是先选出可融合的,因为可能 op 之间有交集,所以并未直接执行融合,而是筛出可融合的组合并排序再进行按序融合),然后再调用can_fuse
()进一步检查是否可融合,最后进行融合,其中两个重要的函数是can_fuse
()检查两个 op 融合是否合法,score_fusion
()对给定的融合 op 排一个优先级(当融合 op 组合冲突时以排序分数高的先融合,排序得分基于<1>节省的内存操作的估计,<2> 尽量保持原始操作顺序);Scheduler.codegen
():codegen_extern_call
():是对部分 kernel 决策进行就地更改并记录决策(没看明白什么操作)self.get_backend(device).codegen_node(node)
:根据 device 调用 codegen_node 生成针对目标硬件的内核代码,如在/usr/local/lib/python3.9/dist-packages/torch/_inductor/codegen/cuda_combined_scheduling.py::codegen_node
中实现了生成 Triton 内核代码。
回到compile_to_module
(),将前面生成内核代码以.py 文件方式(triton 实现)保存到PyCodeCache
中,最后调用PyCodeCache.load_by_key_path
()获得编译后的 module(这个 module 包含 triton 代码的临时文件路径),返回到fx_codegen_and_compile
()函数中将进一步封装成CompiledFxGraph
。
最后回到compile_fx_inner
函数中,若支持 cudagraph 还会对编译后的图进行 cudagraph 编译优化(torch.compile 的 recude-overhead 模式下会自动添加 CUDA Graph 来减小运行时开销)。读到这里,以为 aclgraph 中的 recude-overhead 是专为 npu 添加,原来是来自于这里。
cudagraph 优化具体流程为:
has_incompatible_cudagraph_ops
():检查是否存在与 cudagraph 不兼容的 opcudagraphify
():将子图转为 cudagraph 进行优化
到此完成 TorchInductor 的编译部分,返回一个 Triton 内核代码实现的 CompiledFxGraph,最后一路返回到compile_fx
()即 inductor 的入口,又回到了调用call_user_compiler
处,继续进行后续的操作(Ascend的aclgraph(三)TorchDynamo)完成整个流程。
到此梳理完了torch.compile
()函数的整体流程,解析了从 TorchDynamo 捕获计算图、再到 AOTAutograd 捕获前反向计算图并进行算子decompose
、以及最后在 TorchInductor 中完成算子融合和 kernel 代码生成的实现逻辑,后续再对其中的部分实现细节进行深入分析。
3 小结
用 3 篇的内容介绍了 torch.compile 中的相关概念,后续接着回到 torchair 的代码中,继续分析 aclgraph 的成图逻辑,以及 npu 上的后端是怎么执行的。
评论