Triton-Lang 在 Transformer 优化加速中的实践 | 得物技术
一、前言
众所周知,英伟达(Nvidia)自 2006 年推出 CUDA 以来,经过近 20 年的发展,尤其是经历了以卷积为代表的深度学习和近两年以 Transformer 为基础的 LLM 的推动,CUDA 编程基本上成为了 GPU 编程的代名词。CUDA 作为 GPU 的编程语言,不仅使用户能充分发挥 Nvidia GPU 的高性能的并行计算能力,也逐渐构筑了一个包括硬件、驱动、开发库和编程技巧的完备生态链,从而使 CUDA 成为了人工智能、高性能计算和云计算中的核心依赖。
(图片来源:Triton-lang documentation )
Triton 是 OpenAI 推出的以 python 为编程语言基础,专门为深度学习研发和高性能计算而设计的编程语言和编译器,旨在简化和优化 GPU 编程的复杂操作,降低高性能优化的门槛。
在大模型推理优化领域,已有很多优秀的工作开始应用 Triton 编写高效算子,例如近期被众多大模型推理框架集成的 Attention 算子 FlashAttention、推理加速框架 lightllm、训练加速框架的 Unsloth 等。
Triton 的初期版本以 CUDA 为起点而开发,为没有 CUDA 基础的编程者提供快速编写高效 CUDA kernel 的方案,而随着迭代已逐渐支持其他芯片和编程工具,如 AMD 的 ROCm,并在继续支持其他的芯片,如 Intel 的 CPU。因而,除了简化高性能计算,同时 Triton 也在试图构建一个“CUDA-free”的更高层的 kernel 编写方案,打破“天下苦 CUDA 久矣”的局面,把复杂的对底层芯片的交互,交给其 IR 和底层的编译器。
综上,可以说 Triton 是起于 CUDA,又不止于 CUDA。几个词可以简单总结 Triton 的特点和发展方向:
门槛低
高效
多平台
二、GPU 基础
在学习 Triton 的编程设计前,还是需要了解 GPU 一些简单的基础架构知识和 GPU 编程的基础概念。
以下左图是引自 NVIDIA 经典 Ampere 架构的 GA100(A100)的 datasheet 的整体架构示意图,展现其所有 128 个 SMs(Streaming Multiprocessors)和各级缓存、HBM(高性能内存)和 NvLink(Nvidia 卡间互联)等;而右图是 A100 的单个 SM(Streaming MultiProcessor, 多核流处理器) 的结构。
(图片来源:Nvidia-ampere-architecture-whitepaper )
从硬件的角度来讲,
SP (Streaming Processor 线程处理器) 是 CUDA 编程模型的最基本单位。每个 SP 都有自己的 registers (寄存器) 和 local memory (局部内存, L0 cache)。寄存器和局部内存只能被自己访问,不同的线程处理器之间彼此独立。
由多个线程处理器 (SP) 和一块共享内存(shared memory, L1 cache)构成了一个 SM。多核处理器里边的多个 SP 互相并行,且互不影响。每个 SM 内都有自己的共享内存,shared memory 可以被线程块内所有线程访问。
从软件的角度来讲,
thread(线程):一个 CUDA 程序被分成多个 threads 执行。
block 或 thread block (线程块):多个 threads 群组成一个 block,同一个 block 中的 threads 可以同步,也可以通过 shared memory 传递数据。
grid(网格):多个 blocks 会再构成 grid。
warp:GPU 执行程序时的调度单位。
对应关系:
一个 SP 可以执行一个 thread。
CUDA 的 device 在执行任务时,会把任务分成一个个的 block 分配给 SM 执行, 而每个 block 又会以 warp 为单位执行(Nvidia 把 32 个 threads 组成一个 warp, warp 即是 SM 调度和运行的基本单元,所有 SP 执行同一指令,但每个 thread 使用各自的 data)。
一个 warp 需要占用一个 SM,多个 warps 则会轮流进入 SM 处理。
(图片来源:OpenAI official introduction )
将上述结构大致抽象成 3 个组成部分 DRAM, SRAM 和 ALU, 其中 DRAM 即各个 HBMs(即俗称的显存),SRAM 指各级缓存,ALU 即计算单元(GPU 中的 SM),而当用户优化 CUDA 代码时需要考虑:
DRAM 读写时的内存合并:以保证充分利用 GPU 的内存带宽;
数据必须手动分配至各级 SRAM:以尽可能地避免共享内存冲突;
计算流程必须在 SM 内部和外部谨慎合理地设计、分配和调度:以促进并行线程的计算效率。
而在编程设计时充分考虑以上,即使是对于富有经验的 CUDA 编程者也颇具挑战,因而 Triton 希望底层编译器对多数的调度细节能自动优化,而用户只需要考虑一些顶层的逻辑设计,即 SMs 层级的,例如矩阵分片,SM 之间数据同步等问题。
其官网介绍给出了一个对比,
(表格来源:OpenAI official introduction)
通俗而言,相比于 CUDA,使用 Triton,你不必控制所有内容,因为有些事情可以留给工具自动优化;用 Triton 编写的模块可能不一定优于顶级的 CUDA 算子,但是性能通常能优于普通的 CUDA kernel;而前者的门槛大大低于后者。
因而 Triton 的编程设计过程,其关键在于 SM 层级的并行处理过程的设计,即画好 SM 层级的网格图以表示算子的计算过程。
三、Triton 编程实例
向量求和
内核函数
向量求和对于 Triton 是一个"Hello World"式的示例。使用 Pytorch,对于两个同长度的 vector,直接相加,非常简单。
而对于 Triton,需要编写一个内核函数(kernel)和一个调用函数(wrapper),调用时的并行网格图如下:
kernel 函数代码如下:
@triton.jit 装饰器用于定义内核函数,在程序执行时即时编译并在 GPU 上执行。
x_ptr, y_ptr, output_ptr 分别是两个输入向量和一个输出向量的指针,n_elements 表示向量长度,BLOCK_SIZE 的数据类型为 tl.constexpr,表示一个编译时的常量,定义了每个线程块处理数据时的数据长度。
向量相加虽然简单,但是基本体现了内核函数通常的编写流程,定义维度 -> 计算偏置 -> 设置掩码 -> 读取数据 -> 计算过程 -> 写回数据。
定义维度:当前程序(线程块)通过 tl.program_id 获取自己的 pid, 该程序 id 标识了当前程序的唯一性。tl.program_id 和块大小(BLOCK_SIZE)也决定了并行处理时对整个数据块的划分,比如在这个向量数据的处理时,axis=0 表示一维的划分,再比如矩阵乘法的操作,当我们用分块矩阵的思路设计内核时,则是在二维层面的操作。
计算偏置:得到当前程序的 id 时,我们需要从整个数据块拿取当前程序所需的那块数据,所以需要通过 id 和块大小(BLOCK_SIZE)计算 offsets。需要注意的是,这里的 offsets 是一个 list,即是当前需要的数据的所有索引。
设置掩码:因为数据的长度通常无法被我们预设的块大小整除,比如下图示例中的最后一块,所以需要设置 mask,防止内存操作超出范围。
读取数据:根据输入数据的指针、偏置和掩码,从 DRAM(显存) 读取数据到当前程序所在的 SRAM(缓存)。
计算过程:在这里定义我们所需要的计算流程,例如将两段数据 x 和 y 相加。
写回数据:处理完数据后,同样根据输出数据的指针、偏置和掩码,把结果 output 从 SRAM 写回 DRAM。
线程块在 GPU 的计算模型里又被称为 CTA(Cooperative Thread Array),以上的计算过程相当于一个 CTA 处理单个 block。
而当缓存受限时,我们也可以在单个 CTA 中处理多个 blocks, 如下图和相应的写法:
接口函数
有了内核函数,我们需要再写一个 wrapper,就可以调用内核(好比 Pytorch 的 torch.Add api, 即加号"+")。
这里需要注意两点:
内核程序的运行需要启动一个网格,Triton 以 SPMD(单程序多数据,与 SIMD 类似)的方式执行程序。网格 grid 与内核函数中一开始我们获取的程序标识(id)相对应,在向量处理这个示例中,它是一个一维的网格,数据格式可以是 Callable(metaparameters) -> Tuple[int] ,如上面代码(triton.cdiv 是 Triton 封装的除法,cdiv 表示 ceiling division),也可以直接是 Tuple[int],如(n_elements + BLOCK_SIZE-1)//BLOCK_SIZE。
我们看上述调用内核函数的格式,可以看到,内核函数可以被 grid 索引,每次索引可以得到一个 GPU 内核,启动一个程序。x,y,output 这些张量作为参数传入内核函数的同时,被隐式地转化为指向各自张量第一个元素的指针。
性能测试
Triton 自带性能测试函数,可以帮助衡量自己设计的算子与 baseline 之间的差距。装饰器 @triton.testing.perf_report 用于装饰 benchmark 函数,而 triton.testing.Benchmark 函数定义了 plot 折线图的属性,在 benchmark 函数里面我们可以定义指标来比较不同算子之间的性能,如 Triton 和 Pytorch 算子之间在不同 size 计算下的吞吐差距。
在左图,我们可以看到在向量维度较小时,torch 的计算更快,而当维度较大时,Triton 算子的计算更快一些。
相比向量的计算可能无法体现自定义算子的优势,右图展示了 Triton 官方教程中定义的矩阵乘法算子的性能,可以看到其和 cuBLAS 编写的算子相比已能够达到持平的性能。
Vector Add Triton vs. Pytorch
Matrix multiplication Triton vs cuBLAS
调试
早期的 Triton 核函数开发不支持调试,目前已经支持 pdb 和各种 python ide 的断点调试,只需设置环境变量即可。
矩阵乘法
两个矩阵相乘,
在 Pytorch 中,两个矩阵相乘可以直接以 torch.matmul(A, B)计算得到。而进一步对其稍作优化,我们立刻能想到的通常是分块矩阵。用 Pytorch 表示具体的流程:
A, B 表示两个输入矩阵,其二维尺寸分别为(M, K)和(K, N); BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K 分别为分块时 M, N, K 三个维度的分块尺寸。
主程序用 3 层循环计算,外边 2 层循环分别是依次遍历 A 和 B 两个矩阵的列和行;而最里边的循环,则是对于输出矩阵,每个对应的分块(M, N), 需要 A 和 B 相对应的列和行的分块依次相乘后累加。
从以上逻辑可以看出,这是一个行主序(Row Major)的分块矩阵乘法顺序。
我们可以画出其网格图,以下橙色的 blocks 表示单个 CTA 的计算过程。
按照网格图,将上述的计算过程改写为 Triton 的内核函数:
以上代码,我们关注几点:
网格:采用二维网格,作为程序标识,并计算其输入分块矩阵和输出分块矩阵的偏置,以及分块矩阵的掩码矩阵;
累加:由上图可知,输出矩阵的每个分块,分别由其对应的 K//BLOCK_K 个矩阵相乘的结果累加得到。
行主序和列主序的代码和计算顺序如下,虽说 CUDA 是并行计算的程序,但是当我们将矩阵分为很多的程序执行时,如果我们的 GPU 并没有足够的 SM 来同时执行所有程序,因而这些程序是先后被加载入 SM 计算的。而 CUDA 默认是以列主序存储数据的,所以有时候列主序的程序性能要优于行主序。
行主序:
列主序:
而 Triton 的官网给出了一个基于 L2 cache 优化的方案。其思路是以减少访存次数来提高 cache 的命中率。我们可以从下图比较其与通常的乘法算子的区别。
通常的列主序(column-major-ordering)
分组后的列主序 (grouped-column-major-ordering)
从上可以看出,左边计算 4 个 CTA 时,需要读取 1 列和 4 行,总共要进行 5 次读取;而右边的操作,只需要读取 2 列和 2 行,共 4 次读取。实际计算中,矩阵的行列维度数值都较大,分组后的计算在访存上会有一定的优化,而实际中在例如 A100 的硬件上这样的优化也能有 10%的提升。
以下是官网优化示例给出的核心代码,相比于上述的二维索引,引入 group 之后采用一维索引,而代码的本质则是将这个一维索引 pid 转化为二维索引(pid_m, pid_n),而在这个变化中,我们重新定义了结果矩阵的计算顺序(即上图,右图中区别于左图的元素计算顺序)。
我们用两个 9x9 的矩阵相乘来说明这个索引的过程:
首先 num_pid_m,num_pid_n 分别计算了两个矩阵的 M, N 维度各自块的个数;如下图中 num_pid_m=9,num_pid_n=9。
GROUP_SIZE_M 定义了 group 的维度;如下图中的 GROUP_SIZE_M=3。
num_pid_in_group 为单个 group 中块的个数;如下图中 GROUP_SIZE_M * num_pid_n=27。
group_id 则计算得到 pid 是在哪个 group 中,即 group 的 id;如下图中 pid // num_pid_in_group=1。
first_pid_m 计算的是这个 pid 所在的 group 的第一个块的 pid_m 的值,以方便后续为算 pid 最终的 pid_m, pid_n 提供偏置;如下图中 first_pid_m=3 ,即图中第 27 块的 pid_m。
group_size_m 则是计算了这个 pid 所在的 group 的行数,这是为了避免 M 无法为 GROUP_SIZE_M 所整除时,最后一个 group 的行数小于 GROUP_SIZE_M;如下图的 group 的行数值为 GROUP_SIZE_M,若当图中的行数为 8 时,可以想像最后一行的 group_size_m 为 2。
最后两行代码则是计算 pid 的真正坐标 (pid_m, pid_n)。例如下图的 pid=33, 则 pid_m=3+(33%3)=3, pid_n=(33%27)//3=2。
旋转位置编码
旋转位置编码(RoPE, Rotate Position Encoding)是 Transformer 进入大模型应用时代后的重要算子,在 Llama,ChatGLM 等主流的大模型中都有应用。关于旋转位置编码的原理和作用可以参考原论文和作者博客。其计算过程可以简要表示成以下的旋转变换,
以下是一个 Huggingface 中 Llama RoPE 的前向计算流程。
d 表示 embedding 的维度,则位置编码的相位频率表示如下,m = [0, 2, 4, ..., d/2] , f = (1/10000)^m,
n 表示 token 的个数,
以上是计算 cos 和 sin 两个旋转变换矩阵的过程,而矩阵 q 和 k 在做注意力乘法前先做简单处理。
最后是旋转变换:
我们来看 Triton kernel 对前向过程的实现:
给定矩阵 Q,Cos,Sin, 它们的维度分别为
h 是注意力模块中的 head 数,为方便说明,可令 b=1;需要实现的计算过程是
。
为计算方便,现将 Q reshape 为(n,hd)。
考虑并行的维度时,按 token 并行是首先能想到的,n_rows = n,再考虑到有限的缓存,可以将另一个维度按 group 分,可以定义一个常量 GROUP_SIZE=4(当然这里也可以设计成 autotune 模式,自动选择合适值),然后可以计算得到 embedding 维度的 group 数量 n_groups ,并定义我们的网格 grid。
下图阐释了整体和单元的计算过程。并行程序按两个维度计算,每个 CTA 的计算过程中有个次数为 GROUP_SIZE 的循环累加,各自累加计算得到 q*cos 和 rotate_half(q)*sin,再相加。
四、总结
参考 Triton 的官方示例、以及其他社区的开源加速工具,我们还能看到其他许多算子,诸如 rmsNorm, Softmax, Flash-Attention 等的具体加速方案和并行思路,以及他们的反向传导过程。而这些算子则组成了 Transformer 的注意力模块的整体结构。通过对各个模块的并行优化,可以实现对注意力计算的推理和训练加速。
Transformer 的注意力结构及相关算子
在实际应用中,我们可以自己编写和优化自定义算子,也可以引用社区优秀的开源加速算子库,而且 Pytorch 在 2.0+版本后已将 Triton 集成到其编译器,用 torch.compile()可以直接对已加载模型编译,帮助自动优化可优化的算子。
参考:
Triton-lang 介绍 (https://openai.com/index/triton/)
Triton-tutorial(https://triton-lang.org/main/getting-started/tutorials/index.html)
Nvidia GA100 datasheet (https://images.nvidia.cn/aem-dam/en-zz/Solutions/data-center/nvidia-ampere-architecture-whitepaper.pdf)
CUDA programming(https://docs.nvidia.com/cuda/cuda-c-programming-guide/#kernels)
文 / xujiong
关注得物技术,每周更新技术干货
要是觉得文章对你有帮助的话,欢迎评论转发点赞~
未经得物技术许可严禁转载,否则依法追究法律责任。
评论