CANN-ops-transformer-从输入到输出昇腾NPU跑了多少个融合算子
本文分析了昇腾NPU在Llama2-7B模型推理过程中算子融合的优化效果。通过将Transformer层的11个独立kernel融合为2-3个主要kernel(QKV+RotaryEmbedding融合、FlashAttention融合、FFN融合),显著减少了68%的显存读写量(从121.6GB降至38.4GB)。文章详细拆解了单层Transformer的算子调用链,对比了融合前后的计算流程差异
CANN-ops-transformer-从输入到输出昇腾NPU跑了多少个融合算子
一个 Llama2-7B 的推理请求进来,token 经过 32 层 Transformer 最终输出。每一层里,昇腾NPU到底调了 ops-transformer 仓库的哪些算子?搞清楚这个问题,推理性能调优才有方向。
单层 Transformer 的算子调用链
Llama2-7B 的每一层结构:
输入 x
→ RMSNorm
→ Q/K/V Linear
→ RotaryEmbedding
→ Self-Attention (GQA)
→ Attention Output Linear
→ 残差连接
→ RMSNorm
→ Gate/Up Linear
→ SiLU 激活
→ Down Linear
→ 残差连接
→ 输出
在 ops-transformer 融合算子介入之前,这一层有 11 个独立 kernel。融合后呢?
融合后的算子调用
Attention 部分:3 个 kernel 合成 1 个
原来:
Q Linear → K Linear → V Linear (3 个 MatMul kernel)
RotaryEmbedding (1 个逐元素 kernel)
Q·K^T → Softmax → ·V (3 个 kernel)
Attention Output Linear (1 个 MatMul kernel)
融合后:
QKV + RotaryEmbedding 融合算子 (1 个 kernel)
FlashAttention 融合算子 (1 个 kernel)
Output Linear (1 个 kernel)
Q/K/V 三个 Linear 共享同一个输入 x,MergedMatMul 把它们合成一次 Batch GEMM。RotaryEmbedding 直接塞在 QKV 的输出上,在片上缓存完成旋转。FlashAttention 把 MatMul+Softmax+MatMul 融成一个 kernel。
FFN 部分:3 个 kernel 合成 1-2 个
原来:
Gate Linear → Up Linear (2 个 MatMul kernel)
SiLU (1 个逐元素 kernel)
Gate × Up (1 个逐元素 kernel)
Down Linear (1 个 MatMul kernel)
融合后:
MergedMatMul(Gate+Up) (1 个 kernel)
SiLU + Elementwise + Down Linear (graph-autofusion 自动融合,1-2 个 kernel)
Gate 和 Up Linear 共享输入,MergedMatMul 合并。SiLU 和 elementwise multiply 如果被 graph-autofusion 捕获,可以跟 Down Linear 融成 1 个 kernel;如果没有被捕获,SiLU 和 multiply 单独跑。
数据在 NPU 内的流转
HBM(显存)
↓ DMA 搬入
AIC Cube 单元 ← MatMul 在这里算
↓ 片上缓存
AIV Vector 单元 ← Softmax/SiLU/RoPE 在这里算
↓ 片上缓存
Cube 单元 ← 下一个 MatMul
↓ DMA 搬出
HBM
融合算子的核心价值:让数据在 Cube→Vector→Cube 之间流转时不出 HBM。标准实现每一步都要写回 HBM 再读出来,融合后只在片上缓存里传递。
一次 forward pass 的 HBM 读写量对比:
| 配置 | 每层 HBM 读写 (GB) | 32 层总计 (GB) |
|---|---|---|
| 无融合 | 3.8 | 121.6 |
| ops-transformer 融合 | 1.2 | 38.4 |
减少了 68% 的显存读写。在 HBM 带宽固定的前提下,减少读写 = 减少延迟。
推理服务的完整链路
从请求到响应的全链路:
1. 请求进来,tokenization(CPU)
2. Token Embedding 查表(1 次 HBM 读取)
3. 32 层 Transformer(每层 2-3 个融合 kernel)
4. LayerNorm + Linear 输出 logits(2 个 kernel)
5. Sampling(CPU + 1 次 HBM 读取)
6. 响应返回
步骤 3 占了 85% 以上的时间。在这 85% 里,ops-transformer 的融合算子处理了其中 60% 的计算(Attention + FFN 的融合部分),剩下 40% 是 RMSNorm、残差连接这些 ops-nn 的基础算子。
KV Cache 对算子的影响
推理的 prefill 阶段(首次生成)和 decode 阶段(逐 token 生成)走不同的代码路径:
- Prefill:完整序列进 FlashAttention,Q/K/V 全量计算
- Decode:Q 只有 1 个 token,K/V 从 KV Cache 读取
Decode 阶段的 FlashAttention 实现不同。Q 是 [1, heads, 1, dim],K/V 是 [1, kv_heads, seq, dim]。这种形状下 FlashAttention 的分块策略退化为单行计算——分块没有意义,因为 Q 只有一行。
ops-transformer 对 decode 阶段有专门的 kernel 变体:flash_attention_decode。它省掉了分块逻辑,直接做一次 MatMul + Softmax + MatMul。性能比通用 FlashAttention 在 decode 形状下快 15%。
ATB 会自动判断 prefill/decode 并切换 kernel。手动调 API 时需要传 is_decode=True:
# Decode 阶段
out = torch_npu.npu.flash_attention(q_decode, k_cache, v_cache, is_decode=True)
搞清楚单层 Transformer 调了哪些算子,调优才有抓手。NPU 利用率低?先看 HBM 读写次数——如果融合没生效,读写量会多两三倍。仓库在这里:
https://atomgit.com/cann/ops-transformer
更多推荐



所有评论(0)