写点什么

Triton-Lang 在 Transformer 优化加速中的实践 | 得物技术

作者:得物技术
  • 2025-01-14
    上海
  • 本文字数:9712 字

    阅读完需:约 32 分钟

Triton-Lang在Transformer优化加速中的实践 | 得物技术

一、前言

众所周知,英伟达(Nvidia)自 2006 年推出 CUDA 以来,经过近 20 年的发展,尤其是经历了以卷积为代表的深度学习和近两年以 Transformer 为基础的 LLM 的推动,CUDA 编程基本上成为了 GPU 编程的代名词。CUDA 作为 GPU 的编程语言,不仅使用户能充分发挥 Nvidia GPU 的高性能的并行计算能力,也逐渐构筑了一个包括硬件、驱动、开发库和编程技巧的完备生态链,从而使 CUDA 成为了人工智能、高性能计算和云计算中的核心依赖。



(图片来源:Triton-lang documentation )


Triton 是 OpenAI 推出的以 python 为编程语言基础,专门为深度学习研发和高性能计算而设计的编程语言和编译器,旨在简化和优化 GPU 编程的复杂操作,降低高性能优化的门槛。


在大模型推理优化领域,已有很多优秀的工作开始应用 Triton 编写高效算子,例如近期被众多大模型推理框架集成的 Attention 算子 FlashAttention、推理加速框架 lightllm、训练加速框架的 Unsloth 等。


Triton 的初期版本以 CUDA 为起点而开发,为没有 CUDA 基础的编程者提供快速编写高效 CUDA kernel 的方案,而随着迭代已逐渐支持其他芯片和编程工具,如 AMD 的 ROCm,并在继续支持其他的芯片,如 Intel 的 CPU。因而,除了简化高性能计算,同时 Triton 也在试图构建一个“CUDA-free”的更高层的 kernel 编写方案,打破“天下苦 CUDA 久矣”的局面,把复杂的对底层芯片的交互,交给其 IR 和底层的编译器。


综上,可以说 Triton 是起于 CUDA,又不止于 CUDA。几个词可以简单总结 Triton 的特点和发展方向:


  • 门槛低

  • 高效

  • 多平台

二、GPU 基础

在学习 Triton 的编程设计前,还是需要了解 GPU 一些简单的基础架构知识和 GPU 编程的基础概念。


以下左图是引自 NVIDIA 经典 Ampere 架构的 GA100(A100)的 datasheet 的整体架构示意图,展现其所有 128 个 SMs(Streaming Multiprocessors)和各级缓存、HBM(高性能内存)和 NvLink(Nvidia 卡间互联)等;而右图是 A100 的单个 SM(Streaming MultiProcessor, 多核流处理器) 的结构。




(图片来源:Nvidia-ampere-architecture-whitepaper )


从硬件的角度来讲,


  • SP (Streaming Processor 线程处理器) 是 CUDA 编程模型的最基本单位。每个 SP 都有自己的 registers (寄存器) 和 local memory (局部内存, L0 cache)。寄存器和局部内存只能被自己访问,不同的线程处理器之间彼此独立。

  • 由多个线程处理器 (SP) 和一块共享内存(shared memory, L1 cache)构成了一个 SM。多核处理器里边的多个 SP 互相并行,且互不影响。每个 SM 内都有自己的共享内存,shared memory 可以被线程块内所有线程访问。


从软件的角度来讲,


  • thread(线程):一个 CUDA 程序被分成多个 threads 执行。

  • block 或 thread block (线程块):多个 threads 群组成一个 block,同一个 block 中的 threads 可以同步,也可以通过 shared memory 传递数据。

  • grid(网格):多个 blocks 会再构成 grid。

  • warp:GPU 执行程序时的调度单位。


对应关系:


  • 一个 SP 可以执行一个 thread。

  • CUDA 的 device 在执行任务时,会把任务分成一个个的 block 分配给 SM 执行, 而每个 block 又会以 warp 为单位执行(Nvidia 把 32 个 threads 组成一个 warp, warp 即是 SM 调度和运行的基本单元,所有 SP 执行同一指令,但每个 thread 使用各自的 data)。

  • 一个 warp 需要占用一个 SM,多个 warps 则会轮流进入 SM 处理。



(图片来源:OpenAI official introduction )


将上述结构大致抽象成 3 个组成部分 DRAM, SRAM 和 ALU, 其中 DRAM 即各个 HBMs(即俗称的显存),SRAM 指各级缓存,ALU 即计算单元(GPU 中的 SM),而当用户优化 CUDA 代码时需要考虑:


  • DRAM 读写时的内存合并:以保证充分利用 GPU 的内存带宽;

  • 数据必须手动分配至各级 SRAM:以尽可能地避免共享内存冲突;

  • 计算流程必须在 SM 内部和外部谨慎合理地设计、分配和调度:以促进并行线程的计算效率。


而在编程设计时充分考虑以上,即使是对于富有经验的 CUDA 编程者也颇具挑战,因而 Triton 希望底层编译器对多数的调度细节能自动优化,而用户只需要考虑一些顶层的逻辑设计,即 SMs 层级的,例如矩阵分片,SM 之间数据同步等问题。


其官网介绍给出了一个对比,



(表格来源:OpenAI official introduction)


通俗而言,相比于 CUDA,使用 Triton,你不必控制所有内容,因为有些事情可以留给工具自动优化;用 Triton 编写的模块可能不一定优于顶级的 CUDA 算子,但是性能通常能优于普通的 CUDA kernel;而前者的门槛大大低于后者。


因而 Triton 的编程设计过程,其关键在于 SM 层级的并行处理过程的设计,即画好 SM 层级的网格图以表示算子的计算过程。

三、Triton 编程实例

向量求和

内核函数

向量求和对于 Triton 是一个"Hello World"式的示例。使用 Pytorch,对于两个同长度的 vector,直接相加,非常简单。



size = 1024x = torch.rand(size, device='cuda')y = torch.rand(size, device='cuda')output_torch = x + y
复制代码


而对于 Triton,需要编写一个内核函数(kernel)和一个调用函数(wrapper),调用时的并行网格图如下:



kernel 函数代码如下:


import triton.language as tl
@triton.jitdef add_kernel(x_ptr, # 第一个输入向量的指针 y_ptr, # 第二个输入向量的指针 output_ptr, # 输出向量的指针 n_elements, # 向量长度 BLOCK_SIZE: tl.constexpr, # 每个线程块处理的元素数量 ): # 有多个'程序'处理不同的数据, 用pid标识当前是哪个程序 pid = tl.program_id(axis=0) # 计算当前程序所需要的数据的偏置 block_start = pid * BLOCK_SIZE offsets = block_start + tl.arange(0, BLOCK_SIZE) # 创建一个掩码以防止内存操作超出范围 mask = offsets < n_elements # 从 DRAM 加载 x 和 y x = tl.load(x_ptr + offsets, mask=mask) y = tl.load(y_ptr + offsets, mask=mask) output = x + y # 将计算结果output写回 DRAM tl.store(output_ptr + offsets, output, mask=mask)
复制代码


  • @triton.jit 装饰器用于定义内核函数,在程序执行时即时编译并在 GPU 上执行。

  • x_ptr, y_ptr, output_ptr 分别是两个输入向量和一个输出向量的指针,n_elements 表示向量长度,BLOCK_SIZE 的数据类型为 tl.constexpr,表示一个编译时的常量,定义了每个线程块处理数据时的数据长度。

  • 向量相加虽然简单,但是基本体现了内核函数通常的编写流程,定义维度 -> 计算偏置 -> 设置掩码 -> 读取数据 -> 计算过程 -> 写回数据。

  • 定义维度:当前程序(线程块)通过 tl.program_id 获取自己的 pid, 该程序 id 标识了当前程序的唯一性。tl.program_id 和块大小(BLOCK_SIZE)也决定了并行处理时对整个数据块的划分,比如在这个向量数据的处理时,axis=0 表示一维的划分,再比如矩阵乘法的操作,当我们用分块矩阵的思路设计内核时,则是在二维层面的操作。

  • 计算偏置:得到当前程序的 id 时,我们需要从整个数据块拿取当前程序所需的那块数据,所以需要通过 id 和块大小(BLOCK_SIZE)计算 offsets。需要注意的是,这里的 offsets 是一个 list,即是当前需要的数据的所有索引。

  • 设置掩码:因为数据的长度通常无法被我们预设的块大小整除,比如下图示例中的最后一块,所以需要设置 mask,防止内存操作超出范围。

  • 读取数据:根据输入数据的指针、偏置和掩码,从 DRAM(显存) 读取数据到当前程序所在的 SRAM(缓存)。

  • 计算过程:在这里定义我们所需要的计算流程,例如将两段数据 x 和 y 相加。

  • 写回数据:处理完数据后,同样根据输出数据的指针、偏置和掩码,把结果 output 从 SRAM 写回 DRAM。


线程块在 GPU 的计算模型里又被称为 CTA(Cooperative Thread Array),以上的计算过程相当于一个 CTA 处理单个 block。



而当缓存受限时,我们也可以在单个 CTA 中处理多个 blocks, 如下图和相应的写法:



@triton.jitdef add_kernel(x_ptr, y_ptr, o_ptr, n_elements, num_blocks_per_CTA, BLOCK_SIZE: tl.constexpr,):    pid = tl.program_id(axis=0)      program_offsets = pid * num_blocks_per_CTA * BLOCK_SIZE     offsets = program_offsets + tl.arange(0, BLOCK_SIZE)        for i in range(num_blocks_per_CTA):        mask = offsets < n_elements        x = tl.load(x_ptr + offsets, mask=mask)        y = tl.load(y_ptr + offsets, mask=mask)        output = x + y        tl.store(o_ptr + offsets, output, mask=mask)        offsets += BLOCK_SIZE
复制代码

接口函数

有了内核函数,我们需要再写一个 wrapper,就可以调用内核(好比 Pytorch 的 torch.Add api, 即加号"+")。


def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:    output = torch.empty_like(x)    assert x.is_cuda and y.is_cuda and output.is_cuda    n_elements = output.numel()    # SPMD启动网格,表示并行运行的内核实例数。    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )    # 调用内核函数    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)    # 我们返回一个指向z的句柄,但是,由于`torch.cuda.synchronize()`尚未被调用,内核此时仍在异步运行。    return output
torch.manual_seed(0)size = 98432x = torch.rand(size, device='cuda')y = torch.rand(size, device='cuda')output_triton = add(x, y)
复制代码


这里需要注意两点:


  • 内核程序的运行需要启动一个网格,Triton 以 SPMD(单程序多数据,与 SIMD 类似)的方式执行程序。网格 grid 与内核函数中一开始我们获取的程序标识(id)相对应,在向量处理这个示例中,它是一个一维的网格,数据格式可以是 Callable(metaparameters) -> Tuple[int] ,如上面代码(triton.cdiv 是 Triton 封装的除法,cdiv 表示 ceiling division),也可以直接是 Tuple[int],如(n_elements + BLOCK_SIZE-1)//BLOCK_SIZE。

  • 我们看上述调用内核函数的格式,可以看到,内核函数可以被 grid 索引,每次索引可以得到一个 GPU 内核,启动一个程序。x,y,output 这些张量作为参数传入内核函数的同时,被隐式地转化为指向各自张量第一个元素的指针。

性能测试

Triton 自带性能测试函数,可以帮助衡量自己设计的算子与 baseline 之间的差距。装饰器 @triton.testing.perf_report 用于装饰 benchmark 函数,而 triton.testing.Benchmark 函数定义了 plot 折线图的属性,在 benchmark 函数里面我们可以定义指标来比较不同算子之间的性能,如 Triton 和 Pytorch 算子之间在不同 size 计算下的吞吐差距。


@triton.testing.perf_report(    triton.testing.Benchmark(        x_names=['size'],  # Argument names to use as an x-axis for the plot.        x_vals=[2**i for i in range(12, 28, 1)],  # Different possible values for `x_name`.        x_log=True,  # x axis is logarithmic.        line_arg='provider',  # Argument name whose value corresponds to a different line in the plot.        line_vals=['triton', 'torch'],  # Possible values for `line_arg`.        line_names=['Triton', 'Torch'],  # Label name for the lines.        styles=[('blue', '-'), ('green', '-')],  # Line styles.        ylabel='GB/s',  # Label name for the y-axis.        plot_name='vector-add-performance',  # Name for the plot. Used also as a file name for saving the plot.        args={},  # Values for function arguments not in `x_names` and `y_name`.    ))def benchmark(size, provider):    x = torch.rand(size, device='cuda', dtype=torch.float32)    y = torch.rand(size, device='cuda', dtype=torch.float32)    quantiles = [0.5, 0.2, 0.8]    if provider == 'torch':        ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y, quantiles=quantiles)    if provider == 'triton':        ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y), quantiles=quantiles)    gbps = lambda ms: 3 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)    return gbps(ms), gbps(max_ms), gbps(min_ms)
benchmark.run(print_data=True, show_plots=True, save_path="./")
复制代码


在左图,我们可以看到在向量维度较小时,torch 的计算更快,而当维度较大时,Triton 算子的计算更快一些。


相比向量的计算可能无法体现自定义算子的优势,右图展示了 Triton 官方教程中定义的矩阵乘法算子的性能,可以看到其和 cuBLAS 编写的算子相比已能够达到持平的性能。



Vector Add Triton vs. Pytorch



Matrix multiplication Triton vs cuBLAS

调试

早期的 Triton 核函数开发不支持调试,目前已经支持 pdb 和各种 python ide 的断点调试,只需设置环境变量即可。


os.environ["TRITON_INTERPRET"]=1
复制代码

矩阵乘法

两个矩阵相乘,



在 Pytorch 中,两个矩阵相乘可以直接以 torch.matmul(A, B)计算得到。而进一步对其稍作优化,我们立刻能想到的通常是分块矩阵。用 Pytorch 表示具体的流程:


# Pytorchimport torchfrom typing import Tuple
@torch.jit.scriptdef block_matrix_multiplication(A: torch.Tensor, B:torch.Tensor, M: int, N: int, K: int, BLOCK_SIZE_M: int, BLOCK_SIZE_N: int, BLOCK_SIZE_K: int) -> torch.Tensor: C = torch.zeros((M, N), dtype=torch.float32) for m in range(0, M, BLOCK_SIZE_M): for n in range(0, N, BLOCK_SIZE_N): acc = torch.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=torch.float32) for k in range(0, K, BLOCK_SIZE_K): a = A[m : m+BLOCK_SIZE_M, k : k+BLOCK_SIZE_K] b = B[k : k+BLOCK_SIZE_K, n : n+BLOCK_SIZE_N] acc += torch.matmul(a, b) C[m : m+BLOCK_SIZE_M, n : n+BLOCK_SIZE_N] = acc return C
# 用法示例A = torch.rand(100, 100)B = torch.rand(100, 100)result = block_matrix_multiplication(A, B, 100, 100, 100, 16, 16, 16)
复制代码


  • A, B 表示两个输入矩阵,其二维尺寸分别为(M, K)和(K, N); BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K 分别为分块时 M, N, K 三个维度的分块尺寸。

  • 主程序用 3 层循环计算,外边 2 层循环分别是依次遍历 A 和 B 两个矩阵的列和行;而最里边的循环,则是对于输出矩阵,每个对应的分块(M, N), 需要 A 和 B 相对应的列和行的分块依次相乘后累加。


从以上逻辑可以看出,这是一个行主序(Row Major)的分块矩阵乘法顺序。


我们可以画出其网格图,以下橙色的 blocks 表示单个 CTA 的计算过程。



按照网格图,将上述的计算过程改写为 Triton 的内核函数:


# Triton kernel@triton.jitdet matmul_kernel(A, B, C, M, N, K,     BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):    # 2d grid    pid_m = tl.program_id(0)    pid_n = tl.program_id(1)    offsets_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)    mask_m = offsets_m < M    offsets_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)    mask_n = offsets_n < N        # 2d tile    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)    for start_k in range(0, K, BLOCK_K):        offsets_k = start_k + tl.arange(0, BLOCK_K)        mask_k = offsets_k < K                a_ptrs = A + offsets_m[:, None]*K + offsets_k[None, :]        mask_a = mask_m[:, None] & mask_k[None, :]        b_ptrs = B + offsets_k[:, None]*N + offsets_n[None, :]        mask_b = mask_k[:, None] & mask_n[None, :]            a = tl.load(a_ptrs, mask=mask_a, other=0)        b = tl.load(b_ptrs, mask=mask_b, other=0)        acc += tl.dot(a, b)    c_ptrs = C + offsets_m[:, None]*N + offsets_n    mask_c = mask_m[:, None] & mask_n[None, :]    tl.store(c_ptrs, acc, mask = mask_c)     # grid = (tl.cdiv(M, BLOCK_M), tl.cdiv(N, BLOCK_N), 1)
复制代码


以上代码,我们关注几点:


  • 网格:采用二维网格,作为程序标识,并计算其输入分块矩阵和输出分块矩阵的偏置,以及分块矩阵的掩码矩阵;

  • 累加:由上图可知,输出矩阵的每个分块,分别由其对应的 K//BLOCK_K 个矩阵相乘的结果累加得到。


行主序和列主序的代码和计算顺序如下,虽说 CUDA 是并行计算的程序,但是当我们将矩阵分为很多的程序执行时,如果我们的 GPU 并没有足够的 SM 来同时执行所有程序,因而这些程序是先后被加载入 SM 计算的。而 CUDA 默认是以列主序存储数据的,所以有时候列主序的程序性能要优于行主序。


  • 行主序:


for m in range(0, m, BLOCK_M):    for n in range(0, n, BLOCK_N):         CTA(m, n) ....
复制代码



  • 列主序:


for n in range(0, N, BLOCK_N):    for m in range(0, M, BLOCK_M):         CTA(m, n) ....
复制代码



而 Triton 的官网给出了一个基于 L2 cache 优化的方案。其思路是以减少访存次数来提高 cache 的命中率。我们可以从下图比较其与通常的乘法算子的区别。



通常的列主序(column-major-ordering)



分组后的列主序 (grouped-column-major-ordering)


从上可以看出,左边计算 4 个 CTA 时,需要读取 1 列和 4 行,总共要进行 5 次读取;而右边的操作,只需要读取 2 列和 2 行,共 4 次读取。实际计算中,矩阵的行列维度数值都较大,分组后的计算在访存上会有一定的优化,而实际中在例如 A100 的硬件上这样的优化也能有 10%的提升。


以下是官网优化示例给出的核心代码,相比于上述的二维索引,引入 group 之后采用一维索引,而代码的本质则是将这个一维索引 pid 转化为二维索引(pid_m, pid_n),而在这个变化中,我们重新定义了结果矩阵的计算顺序(即上图,右图中区别于左图的元素计算顺序)。


pid = tl.program_id(axis=0)num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)num_pid_in_group = GROUP_SIZE_M * num_pid_ngroup_id = pid // num_pid_in_groupfirst_pid_m = group_id * GROUP_SIZE_Mgroup_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)pid_m = first_pid_m + (pid % group_size_m)pid_n = (pid % num_pid_in_group) // group_size_m
复制代码


我们用两个 9x9 的矩阵相乘来说明这个索引的过程:


  • 首先 num_pid_m,num_pid_n 分别计算了两个矩阵的 M, N 维度各自块的个数;如下图中 num_pid_m=9,num_pid_n=9。

  • GROUP_SIZE_M 定义了 group 的维度;如下图中的 GROUP_SIZE_M=3。

  • num_pid_in_group 为单个 group 中块的个数;如下图中 GROUP_SIZE_M * num_pid_n=27。

  • group_id 则计算得到 pid 是在哪个 group 中,即 group 的 id;如下图中 pid // num_pid_in_group=1。

  • first_pid_m 计算的是这个 pid 所在的 group 的第一个块的 pid_m 的值,以方便后续为算 pid 最终的 pid_m, pid_n 提供偏置;如下图中 first_pid_m=3 ,即图中第 27 块的 pid_m。

  • group_size_m 则是计算了这个 pid 所在的 group 的行数,这是为了避免 M 无法为 GROUP_SIZE_M 所整除时,最后一个 group 的行数小于 GROUP_SIZE_M;如下图的 group 的行数值为 GROUP_SIZE_M,若当图中的行数为 8 时,可以想像最后一行的 group_size_m 为 2。

  • 最后两行代码则是计算 pid 的真正坐标 (pid_m, pid_n)。例如下图的 pid=33, 则 pid_m=3+(33%3)=3, pid_n=(33%27)//3=2。


旋转位置编码

旋转位置编码(RoPE, Rotate Position Encoding)是 Transformer 进入大模型应用时代后的重要算子,在 Llama,ChatGLM 等主流的大模型中都有应用。关于旋转位置编码的原理和作用可以参考原论文和作者博客。其计算过程可以简要表示成以下的旋转变换,



以下是一个 Huggingface 中 Llama RoPE 的前向计算流程。


  • d 表示 embedding 的维度,则位置编码的相位频率表示如下,m = [0, 2, 4, ..., d/2] , f = (1/10000)^m,


  • n 表示 token 的个数,







以上是计算 cos 和 sin 两个旋转变换矩阵的过程,而矩阵 q 和 k 在做注意力乘法前先做简单处理。


def rotate_half(x):    """Rotates half the hidden dims of the input."""    x1 = x[..., : x.shape[-1] // 2]    x2 = x[..., x.shape[-1] // 2 :]    return torch.cat((-x2, x1), dim=-1)
复制代码


最后是旋转变换:


q_embed = (q*cos) + (rotate_half(q)*sin)k_embed = (k*cos) + (rotate_half(k)*sin)
复制代码


我们来看 Triton kernel 对前向过程的实现:


给定矩阵 Q,Cos,Sin, 它们的维度分别为



h 是注意力模块中的 head 数,为方便说明,可令 b=1;需要实现的计算过程是




为计算方便,现将 Q reshape 为(n,hd)。


考虑并行的维度时,按 token 并行是首先能想到的,n_rows = n,再考虑到有限的缓存,可以将另一个维度按 group 分,可以定义一个常量 GROUP_SIZE=4(当然这里也可以设计成 autotune 模式,自动选择合适值),然后可以计算得到 embedding 维度的 group 数量 n_groups ,并定义我们的网格 grid。


div, mod = divmod(n, GROUP_SIZE)n_groups = div + (mod != 0)grid = (n_rows, n_groups)
复制代码


下图阐释了整体和单元的计算过程。并行程序按两个维度计算,每个 CTA 的计算过程中有个次数为 GROUP_SIZE 的循环累加,各自累加计算得到 q*cos 和 rotate_half(q)*sin,再相加。



def _rope_embedding(    Q,     Q_row_stride,    cos, cos_row_stride,    sin, sin_row_stride,    seqlen,    head_dim      : tl.constexpr,    n_heads       : tl.constexpr,    BLOCK_SIZE    : tl.constexpr,):    """        Calculates the RoPE Embedding quickly        RoPE is Q * cos + rotate_half(Q) * sin        See our blog post for more info    """    GROUP_SIZE = 4    row_position  = tl.program_id(0)    group_head_position = tl.program_id(1)    col_offsets  = tl.arange(0, BLOCK_SIZE)    half_head_dim = head_dim // 2    mask = col_offsets < half_head_dim
sin1 = tl.load(sin + (row_position % seqlen)*sin_row_stride + \ half_head_dim*0 + col_offsets, mask = mask, other = 0) cos1 = tl.load(cos + (row_position % seqlen)*cos_row_stride + \ half_head_dim*0 + col_offsets, mask = mask, other = 0)
head_start = group_head_position * GROUP_SIZE head_end = min((head_start + GROUP_SIZE), n_heads)
for k in range(head_start, head_end): offs_q1 = row_position * Q_row_stride + k * head_dim + col_offsets offs_q2 = row_position * Q_row_stride + k * head_dim + col_offsets + half_head_dim
Q1 = tl.load(Q + offs_q1, mask = mask, other = 0).to(sin1.dtype) Q2 = tl.load(Q + offs_q2, mask = mask, other = 0).to(sin1.dtype)
tl.store(Q + offs_q1, Q1*cos1 - Q2*sin1, mask = mask) tl.store(Q + offs_q2, Q2*cos1 + Q1*sin1, mask = mask)
复制代码

四、总结

参考 Triton 的官方示例、以及其他社区的开源加速工具,我们还能看到其他许多算子,诸如 rmsNorm, Softmax, Flash-Attention 等的具体加速方案和并行思路,以及他们的反向传导过程。而这些算子则组成了 Transformer 的注意力模块的整体结构。通过对各个模块的并行优化,可以实现对注意力计算的推理和训练加速。



Transformer 的注意力结构及相关算子


在实际应用中,我们可以自己编写和优化自定义算子,也可以引用社区优秀的开源加速算子库,而且 Pytorch 在 2.0+版本后已将 Triton 集成到其编译器,用 torch.compile()可以直接对已加载模型编译,帮助自动优化可优化的算子。


参考:


Triton-lang 介绍 (https://openai.com/index/triton/)


Triton-tutorial(https://triton-lang.org/main/getting-started/tutorials/index.html)


Nvidia GA100 datasheet (https://images.nvidia.cn/aem-dam/en-zz/Solutions/data-center/nvidia-ampere-architecture-whitepaper.pdf)


CUDA programming(https://docs.nvidia.com/cuda/cuda-c-programming-guide/#kernels)


文 / xujiong


关注得物技术,每周更新技术干货


要是觉得文章对你有帮助的话,欢迎评论转发点赞~


未经得物技术许可严禁转载,否则依法追究法律责任。

用户头像

得物技术

关注

得物APP技术部 2019-11-13 加入

关注微信公众号「得物技术」

评论

发布
暂无评论
Triton-Lang在Transformer优化加速中的实践 | 得物技术_人工智能_得物技术_InfoQ写作社区