LLM 推理加速:decode 阶段的 Attention 在 GPU 上的优化
作者:董纪莹
随着大语言模型(Large Language Models,LLMs)在各领域的广泛应用,如何以低成本构建高吞吐、低延迟的推理服务成为了一个紧迫的问题。考虑到 LLM 在 GPU 上推理时参数量和计算量较大以致于单流执行就可以充分利用 GPU 资源,我们可以把 LLM 的推理延时分解到 kernel level,因此,进一步的,不考虑时间占比小的 kernel 计算后,LLM 的延时优化也就相应的分解成 GEMM 和 Attention 的 kernel 优化。
RTP-LLM 是阿里巴巴智能引擎团队开发的大模型推理加速引擎,作为一个高性能的大模型推理解决方案,它已被广泛应用于阿里内部。在这篇文章里,我们将基于 RTP-LLM 的实践,介绍 decode 阶段的 Attention 在 GPU 上是如何优化的。
一、背景
我们比较熟悉的 Attention 计算如下图所示,包含 Q 与 K 相乘,其结果在 mask 后做 softmax,然后与 V 相乘,得到 Attention 的结果。在 LLM 推理的 decode 阶段,由于加入 KV Cache 优化,一次迭代只需要计算新增的一个 token,因此计算可以变化为当前 step 的 Q(seq == 1)与 K Cache、V Cache 做计算。
计算过程中各 tensor 的 shape 可以表示为:、
参数的解释如下表:
在本文的分析中,我们考虑简单的 Multi Head Attention 实现,即 H == H_kv。
我们希望以一个 kernel 实现上图的计算。出于性能考虑,将前一步的 BiasAdd,Rotary Embedding 也一起融合。因此这个 kernel 接受的输入是经过 QKV GEMM 的 Q、K、V,在 kernel 中完成 BiasAdd,然后 Q 和 K 会一起做 Rotary Embedding。当前的 K 和 V 会分别与之前计算得到的 KV Cache 做拼接,扩展成(B, H, S, D)的 KV Cache。然后 Q 与 K Cache 相乘,得到的结果在 S 维计算 SoftMax,再与 V Cache 相乘,得到最后的输出。
简化的代码示例如下:
在整个计算过程中,BiasAdd、Rotary Embedding 相对计算量较小,对 kernel 的 latency 影响较小,因此下文省略这一部分的分析。
二、计算分析
我们以当前的 TensorRT-LLM 中 Masked Multi Head Attention(MMHA)的实现为例,分析当前的 MMHA 是怎么实现高性能。
涉及到 GPU 并行计算,我们首先需要考虑的是任务划分。对于这个场景,任务划分实际上是清晰的:B 和 H 是并行维度,在执行过程中的 Q*K 和 QK*V,都可以理解成一个 batch size = B * H 的 Batch GEMV。而 SoftMax 又是一个 Reduce 操作,因此单个 GEMV 的计算最好尽量在一个 block 内完成。因此,MMHA 比较基础的任务划分大概是:
这里的 THREAD_PER_BLOCK 是指每个 block 用多少 threads 来完成一个 head 在 S 上的计算。通常更多的 threads 会更提高每个 SM 的 active warps 以更好的利用计算资源,增加 load 指令以提高数据 load 效率,因此我们希望 THREAD_PER_BLOCK 越大越好(最好接近 1024)。但由于 kernel 整体计算逻辑较为复杂,寄存器用量较大,threads 可能会收到寄存器总量的限制;且在寄存器总量的限制下,我们可以简单的认为每个 SM 上只有一个 active block。
基于这种划分,我们继续考虑每个 block 是如何计算。传入 kernel 的 QKV buffer 实际的 layout 是(B,3, H,D),在 TensorRT-LLM 的实现中,会先 load 当前 step 的 Q 和 K 并计算 BiasAdd 和 ROPE,并将这一步得到的 K Cache 写回 global buffer。完成这些计算后,因为数据还在寄存器中,会直接计算对应的 QK dot。由于这些计算的耗时较短,我们略过这一部分分析,直接看看 TensorRT-LLM 是怎么计算 Q * K Cache 的。
Q 乘 K Cache 的计算在 D 上累加。假设我们用 half 存 KV Cache,用 float 做乘累加,为了保证 load 效率,每个 thread 会 load 连续的 16bytes 数据,也就是 8 个 elements。对于常见的 D==128 来说,需要 16 个 threads 完成一个 head 的计算。可以认为给 block 中的 threads 进行了分组,每组 16 个 threads 负责一个 head 的计算,其中每个 threads 读 8 个 elements,并完成这 8 个 elements 对应的乘累加,然后这组 threads 间通过 warp 内的 shuffle 完成当前 head 的计算,并将计算结果存到 smem 中。组和组在 S 上展开。
接下来计算 SoftMax,由于前面的计算保证了 SoftMax 需要的输入都在当前 block 内的 smem 中,通过 Block Reduce Max 和 Block Reduce Sum 就可以完成 SoftMax 的计算。
乘 V Cache 的计算思路与上文乘 K Cache 非常类似,略有不同的是这一步计算需要在 S 上累加。依然将 threads 分组,每组 16 个 threads 负责一个 head, 每个 thread 负责 8 个 elements 的计算。由于需要在 S 上累加,因此每个 thread 需要保存当前所计 GPUsde 算的 8 个 elements 的部分累加和。最后借助 smem,将不同 threads 上的部分和累加,得到 Attention 的输出。
在计算过程中,qk dot 除了 hfma 计算外,也可以调用 hmma 来完成单个 head 的计算。但由于 kernel 的性能瓶颈在访存上,dot 用哪种计算方式对性能的影响不大;我们的测试也验证了这个结论。
上文的分析中依然省略了一些细节。具体的,比如我们现在通常用 paged KV Block Array 来存储 KV Cache,也就是 KV Cache 可以在 S 维度上不连续,以便在 S 不断增长时动态的分配 buffer。但 paged 的存储并不改变 D 维的连续,因此也不影响上文的分析。此外,每个 thread 在 load KV Cache 时会多 load 一部分存进本地的寄存器,以尽可能的将 load 数据与 dot 计算 overlap。
主流框架如 vllm,xformers 等对 MMHA 的实现和优化思路都是比较类似的,仅在细节处略有差异。TensorRT-LLM 在 mmha 外还实现了 XQA 以继续优化 decode 阶段 Attention 的计算,但由于代码未开源,本文也不做分析。
三、改进与优化
当然上文分析到的简单优化在实际应用中还是不那么够用的,最常见的就是小 B 和长 S 场景。
考虑到实际的 GPU 资源,如 A100 有 108 个 SM,且每个 SM 上只有一个 block(也就是只计算一个 head),当 B * H 恰好占满 108(或 108 的整数倍)个 SM 时,可以认为占用率是比较高的。以 7B 模型,或者 72B 模型 2TP 举例,H = 32,当 B = 3 时,占用率是 88.9%;而当 B = 4 时,就会因必须打两轮而带来占用率的下降到 59%;当 B = 1 时,占用率就会低到 30%了。这个时候如果 S 比较大,我们就会发现,大部分的 device 资源还空闲着,也不得不一起等待部分 SM 完成一个时间很长的计算。
针对这种情况,我们把 S 也分配到 grid dim 上,资源分配也就改为:
在这种任务划分下,结合上文分析,假设长 seq 每个 SM 上仅有一个 active block,则 waves 可以计算为:
当 waves 越接近 ceil 值,意味着 device occupancy 会越高。在小 B 大 S 的场景下,如果在 S 切分,也就是 S_tile > 1,有利于增加 occupancy。在这种情况下,S_tile 个 block 共同完成一个 head 在 S 上的计算,每个 block 负责 S / S_tile 的计算,block 间的 reduce 通过开辟额外的 global buffer 来完成。这种模式下,新增的 global 读写会带来有额外的耗时,但因为增加了 device occupancy,因此在小 B 大 S 的场景下有明显的性能提升。这也就是 flash decoding 的思路,且在各框架均有支持。
除了性能的考虑外,超长 seq 也必须走进这种实现。由于 Q * K 的结果需要在 S 上做 reduce,也就是 smem 需要存下对应大小的中间数据,根据 kernel 实现,输入类型是 half,以 float 累加,可以估计算为 6 * S。而根据 A100 每个 SM 实际可用 smem 是 163KB 计算,最大可支持的 S 在 27K 左右。当输入大于这个值时,我们必须在 seq 做切分,以保证 kernel 的计算。
另一种需要做不同的任务划分的场景是 GQA。在 GQA 的计算下,每个 head 的 KV Cache 会对应于多个 head 的 Q,为了避免 KV Cache 的重复 load,资源分配应该改为,并基于此做计算上的调整。
除了优化任务划分,MMHA 的优化还可以在以下方面继续展开:
1)优化寄存器用量可能达到更高的占用率(可以在一个 SM 上 launch 多个 block 或者增大每个 block 的 threads);
2)继续调整 KV Cache 的 load 行为,让计算和数据读取进一步 overlap 以缓解 memory bound 的场景;
3)在大 B 加上 GQA,Attention 会走到 compute bound,需要调整计算模式以更好的利用 tensor core 加速计算等等。
我们将持续探索和实践,以更灵活、更具拓展性的优化策略来面对日益多样化和复杂的应用场景。优化后的 kernel 会开源在 RTP-LLM 中,欢迎大家交流共建。
参考链接
[01] TensorRT-LLM
https://github.com/NVIDIA/TensorRT-LLM
[02] vllm
https://github.com/vllm-project/vllm
[03] xformers
https://github.com/facebookresearch/xformers
[04] flash decoding
https://crfm.stanford.edu/2023/10/12/flashdecoding.html
[03] RTP-LLM
版权声明: 本文为 InfoQ 作者【阿里技术】的原创文章。
原文链接:【http://xie.infoq.cn/article/78e8a05217596c789f2345a1f】。文章转载请联系作者。
评论