写点什么

Ascend 上的 FlashAttention 实现

作者:zjun
  • 2024-12-18
    上海
  • 本文字数:1622 字

    阅读完需:约 5 分钟

Ascend上的FlashAttention实现

1 FlashAttention

FlashAttention 是一种优化 Transformer 模型计算效率和内存使用的技术。它通过减少存储访问开销(Memory Access Cost,MAC),而非降低 FLOPS(浮点运算次数),来提升性能。

2 前述知识点

涉及到内存访问,肯定与计算的硬件架构有关系。


从 GPU 架构进行解析,参考如下博客:大模型推理加速技术的学习路线是什么首先,我们将探讨 GPU 架构,特别是其内存层次结构。我们确定了两个重要模式:计算限制(compute bound)和内存限制(memory bound),并讨论了大规模 Transformer 推理受内存限制的原因。大部分优化都基于 Transformer 推理受内存限制这一基本事实,例如只要我们提高 FLOP 利用率,就能提高效率。

2.1 GPU 架构

GPU 架构总体如下图所示:



基础部分:DRAM(动态随机存取存储器)、L2 缓存和 SM(流处理器单元)


  • 与 CPU 对比

  • SM 类似于 CPU 核心,但具有更高级的并行性;

  • L2 缓存和 DRAM 类似于 CPU 的 L2 缓存和 DRAM

  • 在 Flash Attention 论文中,L2 缓存被称为 SRAM(静态随机存取存储器)

  • A100 80G SXM

  • 08 个 SM,DRAM 容量为 80GB,有 40M L2 缓存


SM 内部包含什么?



  • L1 缓存:指令和数据

  • 张量核心:进行矩阵乘法运算的地方。回想一下,神经网络计算基本上就是巨大批量的矩阵乘法。


GPU 编程基础


在执行 model.generate(prompt)时,我们进行以下操作:


  • 内存访问:

  • 从高带宽内存(HBM)加载模型权重 -> L2 缓存 -> 传输到 SM(流处理器单元)

  • 计算:

  • 在 SM 中执行矩阵乘法,SM 请求张量核心执行计算

  • A100:

  • 108 个 SM,DRAM 容量为 80G,40M L2 缓存

  • bf16 张量核心:每秒 312 万亿浮点运算(TFLOPS)

  • DRAM 内存带宽为 2039GB/秒 = 2.039T/秒

  • 如果模型很大,我们将其分割到多个 GPU 上,比如两个由 NVLink 连接的 GPU

  • NVLink 300GB/秒 = 0.3T/秒

  • 我们大致观察了速度层次结构。尽管不能直接比较,但它们的数量级差异是我们需要优化的主要方面:

  • 312T(SM 计算) > 2.03T(DRAM 内存访问) > 0.3T=300G(NVLink 跨设备通信) > 60G(PCIe 跨设备通信)

  • 这意味着,如果我们希望速度更快,我们应该尽力:

  • 充分利用 SM

  • 减少单个 GPU 的内存访问(因为它比计算慢得多),减少 GPU 之间的通信(因为它甚至比内存访问还要慢)。


计算限制与内存限制


如何确定我们是否充分利用了 SM 呢?我们通过以下方式检查是否计算或内存限制:


定义每字节 GPU 操作 = flop / 内存带宽


  • A100 = 312 / 2.039

  • 定义计算强度 = 计算 / 内存访问

  • 如果计算强度大,说明程序更会受到计算限制;如果计算强度较小,则更受内存限制。


  • 增加批次大小会将行为从内存限制变为计算限制。

  • 内核融合:减少了内存访问操作,因为我们将多个操作合并为一个操作。

2.2 Transformer 推理

内存布局



正如我们所看到的,为了在 bf16 格式下运行一个 13B 模型,我们大约只有 10GB 的内存来存储 kv 缓存。这意味着:


  • 不能使用太大型的批次(尽管我们希望使用更大的批次大小以提高效率)

  • 也不能处理太长的序列,尽管我们确实希望能够处理长度为 100k 的序列。

3 FlashAttention 的策略

FlashAttention 的核心策略包括:


  • Tiling(平铺/切分):将注意力矩阵分解成更小的子矩阵,分别计算,确保每个子矩阵的大小适合 SRAM(静态随机存取存储器)的存储能力,从而减少对 HBM(高带宽内存)的访问。

  • Recomputation(重算):在反向传播时,不存储所有中间状态,而是在需要时重新计算,节省内存。

  • 分块 SoftMax:解决标准 SoftMax 在分块计算中的问题,确保整个 Flash Attention 的正确性。

  • 优化显存交换:减少 SRAM 与 HBM 之间的数据交换,加速计算。这些策略共同作用,使 FlashAttention 在保持计算精度的同时,显著提高计算速度和内存效率

4 Ascend 上的 FlashAttention

昇腾异构计算架构 CANN 针对昇腾 AI 处理器的片上内存和缓存大小,以及数据搬运通路,基于 Ascend C 算子编程语言优化实现 FlashAttention 融合算子,充分利用片上缓存,提升 Attention 处理性能。根据实测,在一些典型场景中 CANN 的 FlashAttention 算子相比小算子取得了 5 倍以上的性能提升,开发者可直接调用相关算子 API 接口使能大模型极致性能优化。


可参考:基于Ascend C的FlashAttention算子性能优化最佳实践-技术干货-昇腾社区

用户头像

zjun

关注

还未添加个人签名 2020-03-06 加入

还未添加个人简介

评论

发布
暂无评论
Ascend上的FlashAttention实现_Transformer_zjun_InfoQ写作社区