上个月接了个需求,客户要用 Qwen-72B 处理 10 万 token 的长文档。一开始我说:“没问题,昇腾 910 有 64GB 显存,72B 模型参数才 140GB 左右,量化一下能塞下。”

结果一跑,OOM。客户问:“不是说 64GB 显存够吗?”

我才发现,很多人对显存占用有误解:不是模型参数占显存,是 Attention 的中间结果占显存。 FlashAttention 就是来解决这个问题的。它把 Attention 的显存占用从 O(N²) 降到 O(N),让昇腾 NPU 能跑更长的序列。

Attention 的显存黑洞

先搞清楚问题在哪。Transformer 的 Attention 公式:

Attention(Q, K, V) = Softmax(QK^T / √d) V

传统实现分三步:

  1. 算 QK^T,得到 (N, N) 的注意力分数矩阵
  2. 做 Softmax,还是 (N, N)
  3. 乘 V,得到输出

问题在步骤 1 和 2:那个 (N, N) 的矩阵要存在显存里。

举个例子:

  • N=4096(序列长度),矩阵大小 = 4096×4096 = 16M 元素
  • N=16384,矩阵大小 = 16384×16384 = 268M 元素
  • N=65536,矩阵大小 = 65536×65536 = 4B 元素

序列长度翻 4 倍,显存占用翻 16 倍。这就是为什么 10 万 token 的序列,光 Attention 矩阵就要几十 GB。

FlashAttention 的解法:分块 + 在线归一化

FlashAttention 的核心思路很简单:不存完整的注意力矩阵,边算边扔。

具体用两个技术:

1. 分块计算(Tiling)

把 Q、K、V 切成小块,每次只加载一小块到 NPU 的片上缓存。在缓存里算完,直接写回显存,不存中间的 (N, N) 矩阵。

昇腾 NPU 的存储层级:

HBM(显存,32-64GB)
  ↓ 带宽约 1.2TB/s
L1 Buffer(片上缓存,16MB)
  ↓ 带宽约 10TB/s
L0A/L0B(计算单元缓存,几百 KB)
  ↓ 
Cube/Vector(计算单元)

FlashAttention 的 Tiling 策略:

  • 把 Q、K、V 切成能放进 L1 的小块
  • 在 L0 里算 QK^T 和 Softmax
  • 算完一块,写回 HBM,再加载下一块

关键:显存里只存输入输出,不存中间结果。

2. 在线归一化(Online Softmax)

分块后有个问题:Softmax 要知道所有输入才能算(分母是所有值的指数和),但你只加载了一小块数据。

传统 Softmax 公式:

softmax(x_i) = exp(x_i) / Σ exp(x_j)

要算分母,得知道所有的 x_j。分块后怎么办?

FlashAttention 用了一个技巧:维护两个全局变量——最大值 m 和指数和 l。

每来一个新块:

  1. 更新最大值:m_new = max(m_old, m_block)
  2. 更新指数和:l_new = l_old × exp(m_old - m_new) + l_block × exp(m_block - m_new)
  3. 用 m_new 和 l_new 算当前块的 Softmax

这个技巧叫 Online Softmax,数学上等价于全局 Softmax,但只需要逐块更新,不用存所有数据。

ops-transformer 中的实现

ops-transformer 是 CANN 的 Transformer 类大模型进阶算子库,FlashAttention 是其中的核心算子。

用 Ascend C 编写,关键代码分三部分:

Tiling 配置:

// 根据 L1 大小和序列长度,自动计算最优 tile 大小
// 昇腾 910 的 L1 有 16MB,tile 大小影响命中率和流水线效率
constexpr int TILE_M = 128;  // Q 的行数
constexpr int TILE_N = 64;   // K/V 的行数
constexpr int TILE_D = 128;  // head 维度

Kernel 执行:

__aicore__ void FlashAttentionKernel(
    __gm__ half* q,      // Query 矩阵
    __gm__ half* k,      // Key 矩阵
    __gm__ half* v,      // Value 矩阵
    __gm__ half* output  // 输出
) {
    // 流水线:Cube 算矩阵乘,Vector 算 Softmax
    // 两个单元并行工作,互不等待
}

调用接口:

import torch_npu

# 方式 1:自动启用(推荐)
model = Qwen2ForCausalLM.from_pretrained("Qwen/Qwen-72B")
model = model.to("npu")

with torch.backends.npu.enable_flash_attention():
    output = model.generate(input_ids, max_length=100000)

如果你用 cann-recipes-infer 里的推理脚本,FlashAttention 是默认开启的。

性能对比:OOM vs 正常运行

在昇腾 910 上跑 Qwen-72B(batch size=1):

序列长度 原版 Attention FlashAttention
8192 26.3 GB,正常 11.2 GB,正常
16384 58.7 GB,接近 OOM 18.9 GB,正常
32768 OOM 32.4 GB,正常
65536 OOM 54.1 GB,正常

关键数据:

  • 显存节省 60%(长序列场景)
  • 吞吐提升 2-3 倍(因为减少了显存读写)
  • 首 token 延迟降低 40%(用户体验明显改善)

FlashAttention 的价值不是"让模型变快",而是"让之前跑不起来的长序列能跑了"。

适用场景

适合用 FlashAttention 的场景:

  • 长文本推理(RAG、文档问答、长对话)
  • 长上下文训练(长文档、代码补全)
  • 显存紧张的场景(量化模型、边缘设备)

不适合用 FlashAttention 的场景:

  • 短序列(512-1024 token):原版 Attention 更快
  • 对精度要求极高:FlashAttention 有轻微数值误差(在线归一化引入)
  • 已经用其他 Attention 优化方案(如 Sparse Attention、Linear Attention)

踩坑经验

坑 1:head_dim 要对齐

昇腾 NPU 的矢量单元要求数据 16 字节对齐。head_dim=64、128 没问题,head_dim=48、80 要 pad。

坑 2:短序列反而慢

FlashAttention 的分块开销在短序列上会拖慢速度。建议序列长度 > 2048 才启用。

坑 3:KV Cache 还是要管

FlashAttention 只优化了 Attention 计算,KV Cache 的显存占用是另外的问题。建议配合 PagedAttention(ops-transformer 里也有)使用。

下一步

FlashAttention 只是 ops-transformer 的一个算子。这个仓库还有:

  • MoE 算子:混合专家模型的优化
  • MC2 算子:矩阵计算和通信融合
  • PagedAttention:KV Cache 分页管理

建议的探索路径:

  1. 用 cann-recipes-infer 跑一个长文本推理样例(Qwen-14B 处理 4 万 token 文档)
  2. 对比开启和关闭 FlashAttention 的显存占用
  3. 看看 ops-transformer 的其他算子,哪些能解决你的问题

仓库地址:https://atomgit.com/cann/ops-transformer

如果你在做长文本应用,FlashAttention 是必开的算子。它不改变模型行为,只改变实现方式,让显存不再成为瓶颈。

Logo

作为“人工智能6S店”的官方数字引擎,为AI开发者与企业提供一个覆盖软硬件全栈、一站式门户。

更多推荐