FlashAttention:让长文本推理不再“卡显存“
上个月接了个需求,客户要用 Qwen-72B 处理 10 万 token 的长文档。一开始我说:“没问题,昇腾 910 有 64GB 显存,72B 模型参数才 140GB 左右,量化一下能塞下。结果一跑,OOM。客户问:“不是说 64GB 显存够吗?FlashAttention 就是来解决这个问题的。它把 Attention 的显存占用从 O(N²) 降到 O(N),让昇腾 NPU 能跑更长的序列
上个月接了个需求,客户要用 Qwen-72B 处理 10 万 token 的长文档。一开始我说:“没问题,昇腾 910 有 64GB 显存,72B 模型参数才 140GB 左右,量化一下能塞下。”
结果一跑,OOM。客户问:“不是说 64GB 显存够吗?”
我才发现,很多人对显存占用有误解:不是模型参数占显存,是 Attention 的中间结果占显存。 FlashAttention 就是来解决这个问题的。它把 Attention 的显存占用从 O(N²) 降到 O(N),让昇腾 NPU 能跑更长的序列。
Attention 的显存黑洞
先搞清楚问题在哪。Transformer 的 Attention 公式:
Attention(Q, K, V) = Softmax(QK^T / √d) V
传统实现分三步:
- 算 QK^T,得到 (N, N) 的注意力分数矩阵
- 做 Softmax,还是 (N, N)
- 乘 V,得到输出
问题在步骤 1 和 2:那个 (N, N) 的矩阵要存在显存里。
举个例子:
- N=4096(序列长度),矩阵大小 = 4096×4096 = 16M 元素
- N=16384,矩阵大小 = 16384×16384 = 268M 元素
- N=65536,矩阵大小 = 65536×65536 = 4B 元素
序列长度翻 4 倍,显存占用翻 16 倍。这就是为什么 10 万 token 的序列,光 Attention 矩阵就要几十 GB。
FlashAttention 的解法:分块 + 在线归一化
FlashAttention 的核心思路很简单:不存完整的注意力矩阵,边算边扔。
具体用两个技术:
1. 分块计算(Tiling)
把 Q、K、V 切成小块,每次只加载一小块到 NPU 的片上缓存。在缓存里算完,直接写回显存,不存中间的 (N, N) 矩阵。
昇腾 NPU 的存储层级:
HBM(显存,32-64GB)
↓ 带宽约 1.2TB/s
L1 Buffer(片上缓存,16MB)
↓ 带宽约 10TB/s
L0A/L0B(计算单元缓存,几百 KB)
↓
Cube/Vector(计算单元)
FlashAttention 的 Tiling 策略:
- 把 Q、K、V 切成能放进 L1 的小块
- 在 L0 里算 QK^T 和 Softmax
- 算完一块,写回 HBM,再加载下一块
关键:显存里只存输入输出,不存中间结果。
2. 在线归一化(Online Softmax)
分块后有个问题:Softmax 要知道所有输入才能算(分母是所有值的指数和),但你只加载了一小块数据。
传统 Softmax 公式:
softmax(x_i) = exp(x_i) / Σ exp(x_j)
要算分母,得知道所有的 x_j。分块后怎么办?
FlashAttention 用了一个技巧:维护两个全局变量——最大值 m 和指数和 l。
每来一个新块:
- 更新最大值:m_new = max(m_old, m_block)
- 更新指数和:l_new = l_old × exp(m_old - m_new) + l_block × exp(m_block - m_new)
- 用 m_new 和 l_new 算当前块的 Softmax
这个技巧叫 Online Softmax,数学上等价于全局 Softmax,但只需要逐块更新,不用存所有数据。
ops-transformer 中的实现
ops-transformer 是 CANN 的 Transformer 类大模型进阶算子库,FlashAttention 是其中的核心算子。
用 Ascend C 编写,关键代码分三部分:
Tiling 配置:
// 根据 L1 大小和序列长度,自动计算最优 tile 大小
// 昇腾 910 的 L1 有 16MB,tile 大小影响命中率和流水线效率
constexpr int TILE_M = 128; // Q 的行数
constexpr int TILE_N = 64; // K/V 的行数
constexpr int TILE_D = 128; // head 维度
Kernel 执行:
__aicore__ void FlashAttentionKernel(
__gm__ half* q, // Query 矩阵
__gm__ half* k, // Key 矩阵
__gm__ half* v, // Value 矩阵
__gm__ half* output // 输出
) {
// 流水线:Cube 算矩阵乘,Vector 算 Softmax
// 两个单元并行工作,互不等待
}
调用接口:
import torch_npu
# 方式 1:自动启用(推荐)
model = Qwen2ForCausalLM.from_pretrained("Qwen/Qwen-72B")
model = model.to("npu")
with torch.backends.npu.enable_flash_attention():
output = model.generate(input_ids, max_length=100000)
如果你用 cann-recipes-infer 里的推理脚本,FlashAttention 是默认开启的。
性能对比:OOM vs 正常运行
在昇腾 910 上跑 Qwen-72B(batch size=1):
| 序列长度 | 原版 Attention | FlashAttention |
|---|---|---|
| 8192 | 26.3 GB,正常 | 11.2 GB,正常 |
| 16384 | 58.7 GB,接近 OOM | 18.9 GB,正常 |
| 32768 | OOM | 32.4 GB,正常 |
| 65536 | OOM | 54.1 GB,正常 |
关键数据:
- 显存节省 60%(长序列场景)
- 吞吐提升 2-3 倍(因为减少了显存读写)
- 首 token 延迟降低 40%(用户体验明显改善)
FlashAttention 的价值不是"让模型变快",而是"让之前跑不起来的长序列能跑了"。
适用场景
适合用 FlashAttention 的场景:
- 长文本推理(RAG、文档问答、长对话)
- 长上下文训练(长文档、代码补全)
- 显存紧张的场景(量化模型、边缘设备)
不适合用 FlashAttention 的场景:
- 短序列(512-1024 token):原版 Attention 更快
- 对精度要求极高:FlashAttention 有轻微数值误差(在线归一化引入)
- 已经用其他 Attention 优化方案(如 Sparse Attention、Linear Attention)
踩坑经验
坑 1:head_dim 要对齐
昇腾 NPU 的矢量单元要求数据 16 字节对齐。head_dim=64、128 没问题,head_dim=48、80 要 pad。
坑 2:短序列反而慢
FlashAttention 的分块开销在短序列上会拖慢速度。建议序列长度 > 2048 才启用。
坑 3:KV Cache 还是要管
FlashAttention 只优化了 Attention 计算,KV Cache 的显存占用是另外的问题。建议配合 PagedAttention(ops-transformer 里也有)使用。
下一步
FlashAttention 只是 ops-transformer 的一个算子。这个仓库还有:
- MoE 算子:混合专家模型的优化
- MC2 算子:矩阵计算和通信融合
- PagedAttention:KV Cache 分页管理
建议的探索路径:
- 用 cann-recipes-infer 跑一个长文本推理样例(Qwen-14B 处理 4 万 token 文档)
- 对比开启和关闭 FlashAttention 的显存占用
- 看看 ops-transformer 的其他算子,哪些能解决你的问题
仓库地址:https://atomgit.com/cann/ops-transformer
如果你在做长文本应用,FlashAttention 是必开的算子。它不改变模型行为,只改变实现方式,让显存不再成为瓶颈。
更多推荐




所有评论(0)