CANN-ops-blas推理场景-昇腾NPU上decode阶段的GEMM怎么扛住低利用率

上一篇讲了 Tiling 决定 GEMM 性能。但 decode 阶段的 MatMul 是 [1, hidden] × [hidden, ff_dim],M 维只有 1——Tiling 再怎么切也打不满 Cube 单元。ops-blas 针对这种长条形矩阵有专门的优化策略。

Decode 阶段的 GEMM 特征

Llama2-7B decode 阶段的 MatMul 参数:

QKV Linear:  [1, 4096] × [4096, 3×4096]  → K=4096, N=12288
O Linear:    [1, 4096] × [4096, 4096]   → K=4096, N=4096
Gate/Up:     [1, 4096] × [4096, 28672]  → K=4096, N=28672
Down:        [1, 14336] × [14336, 4096] → K=14336, N=4096

M 维全是 1。Cube 单元一次算 [16×16] 的块,M=1 时只能用 1/16 的算力。理论峰值 310 TFLOPS,实际只跑出 15-20 TFLOPS,利用率不到 7%。

这不是昇腾NPU独有的问题——所有 GPU/NPU 在 decode 阶段的 MatMul 利用率都很低。NVIDIA A100 的 decode GEMM 利用率也在 5-10%。

ops-blas 的应对策略

既然 Cube 单元打不满,就把瓶颈从计算转到搬运——优化 HBM 带宽利用率。

策略 1:权重预取。 Decode 阶段的权重矩阵很大(Llama2-7B 的 FFN 权重约 7GB),每次 MatMul 都要读一遍。ops-blas 在计算当前层的权重时,用 DMA 引擎提前把下一层的权重搬到 L2 缓存。

时间线:
Layer 0: [DMA读权重1] → [Cube算] → [DMA读权重2] → [Cube算] → ...
预取优化:
Layer 0: [DMA读权重1] → [Cube算 + DMA预取权重2] → [Cube算权重2 + DMA预取权重3] → ...

计算和权重搬运重叠,HBM 带宽利用率从 50% 提升到 90%。

策略 2:KV Cache 连续读取。 Decode 阶段的 Attention 需要读 KV Cache,每次读取的 K/V 形状是 [kv_heads, seq_len, dim]。seq_len 随生成步数递增,读取模式不规律。

ops-blas 的 GEMV(矩阵-向量乘法)实现会把 K/V 的行按 64 元素对齐存储。这样 DMA 一次搬 64 个 fp16 元素(128 字节),刚好是 HBM 的一次 burst 传输。不对齐的话一次可能只搬 1-2 个元素,带宽浪费 98%。

策略 3:Multi-batch 打包。 推理服务同时处理多个请求时,把多个请求的 token 拼成一个 batch:

请求 A: [1, 4096]
请求 B: [1, 4096]
请求 C: [1, 4096]
打包后:  [3, 4096] × [4096, 14336]

M 维从 1 变成 3,Cube 利用率提升 3 倍。这就是 continuous batching 的底层依据——不只是减少空闲时间,还让 GEMM 的 M 维变大。

和 ATB 的配合

ATB 的推理服务默认开启 continuous batching。它把并发的 decode 请求打包成 batch 送进 GEMM:

from atb import LLM

model = LLM("meta-llama/Llama-2-7b-hf", device="npu:0",
            max_batch_size=32)  # 最大 batch 32

# 多个请求并行 decode,ATB 自动打包
results = model.generate(["Hello", "Hi", "Hey"])  # 3 个请求打包成 batch=3

batch=32 时 M=32,GEMM 利用率约 60%。batch=1 时 M=1,利用率只有 7%。差距 8 倍。

性能数据

Atlas 800I A2,Llama2-7B decode:

batch GEMM 利用率 decode 吞吐 (tokens/s)
1 7% 3,200
4 25% 10,800
8 45% 18,500
16 60% 24,000
32 65% 26,000

batch 从 1 到 8,吞吐提升 5.8 倍——这主要来自 GEMM 利用率的提升。


Decode 阶段的 GEMM 低利用率是硬件限制,但 continuous batching 可以从外部改善。推理服务一定要开 batch——不只是并发优化,是让底层 GEMM 的 M 维变大、Cube 利用率上来。仓库在这里:

https://atomgit.com/cann/ops-blas

Logo

作为“人工智能6S店”的官方数字引擎,为AI开发者与企业提供一个覆盖软硬件全栈、一站式门户。

更多推荐