CANN-ops-blas推理场景-昇腾NPU上decode阶段的GEMM怎么扛住低利用率
《昇腾NPU上decode阶段GEMM优化策略》摘要 针对昇腾NPU在LLM推理decode阶段GEMM利用率低的问题(仅7%),文章分析了长条形矩阵运算特征并提出优化方案。关键策略包括:1)权重预取实现计算与搬运重叠,提升HBM带宽至90%;2)KV Cache连续读取优化,采用64元素对齐存储;3)多请求打包成batch,使M维从1增至32,利用率提升至65%。实验显示batch=32时吞吐量
CANN-ops-blas推理场景-昇腾NPU上decode阶段的GEMM怎么扛住低利用率
上一篇讲了 Tiling 决定 GEMM 性能。但 decode 阶段的 MatMul 是 [1, hidden] × [hidden, ff_dim],M 维只有 1——Tiling 再怎么切也打不满 Cube 单元。ops-blas 针对这种长条形矩阵有专门的优化策略。
Decode 阶段的 GEMM 特征
Llama2-7B decode 阶段的 MatMul 参数:
QKV Linear: [1, 4096] × [4096, 3×4096] → K=4096, N=12288
O Linear: [1, 4096] × [4096, 4096] → K=4096, N=4096
Gate/Up: [1, 4096] × [4096, 28672] → K=4096, N=28672
Down: [1, 14336] × [14336, 4096] → K=14336, N=4096
M 维全是 1。Cube 单元一次算 [16×16] 的块,M=1 时只能用 1/16 的算力。理论峰值 310 TFLOPS,实际只跑出 15-20 TFLOPS,利用率不到 7%。
这不是昇腾NPU独有的问题——所有 GPU/NPU 在 decode 阶段的 MatMul 利用率都很低。NVIDIA A100 的 decode GEMM 利用率也在 5-10%。
ops-blas 的应对策略
既然 Cube 单元打不满,就把瓶颈从计算转到搬运——优化 HBM 带宽利用率。
策略 1:权重预取。 Decode 阶段的权重矩阵很大(Llama2-7B 的 FFN 权重约 7GB),每次 MatMul 都要读一遍。ops-blas 在计算当前层的权重时,用 DMA 引擎提前把下一层的权重搬到 L2 缓存。
时间线:
Layer 0: [DMA读权重1] → [Cube算] → [DMA读权重2] → [Cube算] → ...
预取优化:
Layer 0: [DMA读权重1] → [Cube算 + DMA预取权重2] → [Cube算权重2 + DMA预取权重3] → ...
计算和权重搬运重叠,HBM 带宽利用率从 50% 提升到 90%。
策略 2:KV Cache 连续读取。 Decode 阶段的 Attention 需要读 KV Cache,每次读取的 K/V 形状是 [kv_heads, seq_len, dim]。seq_len 随生成步数递增,读取模式不规律。
ops-blas 的 GEMV(矩阵-向量乘法)实现会把 K/V 的行按 64 元素对齐存储。这样 DMA 一次搬 64 个 fp16 元素(128 字节),刚好是 HBM 的一次 burst 传输。不对齐的话一次可能只搬 1-2 个元素,带宽浪费 98%。
策略 3:Multi-batch 打包。 推理服务同时处理多个请求时,把多个请求的 token 拼成一个 batch:
请求 A: [1, 4096]
请求 B: [1, 4096]
请求 C: [1, 4096]
打包后: [3, 4096] × [4096, 14336]
M 维从 1 变成 3,Cube 利用率提升 3 倍。这就是 continuous batching 的底层依据——不只是减少空闲时间,还让 GEMM 的 M 维变大。
和 ATB 的配合
ATB 的推理服务默认开启 continuous batching。它把并发的 decode 请求打包成 batch 送进 GEMM:
from atb import LLM
model = LLM("meta-llama/Llama-2-7b-hf", device="npu:0",
max_batch_size=32) # 最大 batch 32
# 多个请求并行 decode,ATB 自动打包
results = model.generate(["Hello", "Hi", "Hey"]) # 3 个请求打包成 batch=3
batch=32 时 M=32,GEMM 利用率约 60%。batch=1 时 M=1,利用率只有 7%。差距 8 倍。
性能数据
Atlas 800I A2,Llama2-7B decode:
| batch | GEMM 利用率 | decode 吞吐 (tokens/s) |
|---|---|---|
| 1 | 7% | 3,200 |
| 4 | 25% | 10,800 |
| 8 | 45% | 18,500 |
| 16 | 60% | 24,000 |
| 32 | 65% | 26,000 |
batch 从 1 到 8,吞吐提升 5.8 倍——这主要来自 GEMM 利用率的提升。
Decode 阶段的 GEMM 低利用率是硬件限制,但 continuous batching 可以从外部改善。推理服务一定要开 batch——不只是并发优化,是让底层 GEMM 的 M 维变大、Cube 利用率上来。仓库在这里:
https://atomgit.com/cann/ops-blas
更多推荐



所有评论(0)