CANN-ops-nn融合MatMul加LayerNorm-昇腾NPU上两个最忙算子怎么省一遍读写
摘要:本文介绍了在昇腾NPU上优化Transformer模型中MatMul和LayerNorm算子的融合技术。通过将这两个高频算子合并为一个操作,减少中间结果的HBM读写次数,可降低17-28%的延迟。具体实现使用torch_npu.npu.fused_linear_act_norm接口,将MatMul、Bias、激活函数和LayerNorm四合一处理,使中间数据在片上缓存流转。测试显示32层模型
CANN-ops-nn融合MatMul加LayerNorm-昇腾NPU上两个最忙算子怎么省一遍读写
大模型的每一层 Transformer 里,MatMul 和 LayerNorm 是出勤率最高的两个算子。标准实现下它们各跑各的——MatMul 在 Cube 单元算完写回 HBM,LayerNorm 从 HBM 读出来在 Vector 单元算完再写回去。ops-nn 的融合版本让这两个算子共享片上缓存,省掉中间那一轮 HBM 读写。
为什么是这两个
单层 Transformer 里 MatMul 出现的次数:
Q Linear, K Linear, V Linear → 3 次 MatMul
Attention Output Linear → 1 次 MatMul
Gate Linear, Up Linear, Down Linear → 3 次 MatMul
LayerNorm(或 RMSNorm)出现 2 次:Attention 前一次,FFN 前一次。
而且它们的搭配是固定的:每组 MatMul 的输出 → 激活函数 → 下一个操作 → LayerNorm。数据流是线性的,融合条件天然满足。
融合前后的数据流
标准实现:
HBM → Cube(MatMul) → HBM → Vector(Activation) → HBM → Vector(LayerNorm) → HBM
4 次 HBM 读写
融合实现:
HBM → Cube(MatMul) → 片上缓存 → Vector(Activation) → 片上缓存 → Vector(LayerNorm) → HBM
2 次 HBM 读写(只读输入、只写最终输出)
HBM 读写减半。在昇腾NPU上 HBM 带宽约 1.2 TB/s,省一次 4096×4096 的 float16 读写(32MB)就是省约 27μs 的延迟。单层数字不大,32 层叠起来就是 0.86ms。
0.86ms 听起来也不多?在 decode 阶段,每生成一个 token 只需要 3-5ms,0.86ms 就是 17-28% 的延迟优化。
ops-nn 的融合接口
import torch_npu
# MatMul + Bias + SiLU + LayerNorm 四合一
out = torch_npu.npu.fused_linear_act_norm(
x, # 输入 [batch, seq, hidden]
weight, # 权重 [hidden, ff_dim]
bias, # 偏置 [ff_dim]
norm_weight, # LayerNorm 权重 [ff_dim]
norm_bias, # LayerNorm 偏置 [ff_dim]
activation="silu",
eps=1e-5
)
这个接口把四个操作压进一个 kernel。中间结果全部在 AI Core 的片上缓存里流转,不出 HBM。
不过有个限制:融合接口要求 MatMul 的输出维度和 LayerNorm 的归一化维度一致。大部分标准 Transformer 模型满足这个条件。如果你做了自定义的维度变换(比如输出后接一个 reshape 再做 LayerNorm),融合会失效。
性能对比
Atlas 800I A2,Llama2-7B decode 阶段单层:
| 配置 | 单层延迟 (ms) | HBM 读写 (GB) |
|---|---|---|
| 标准分离 | 3.8 | 1.2 |
| MatMul+Activation 融合 | 3.2 | 0.9 |
| MatMul+Act+LayerNorm 融合 | 2.6 | 0.6 |
32 层总计延迟从 121ms 降到 83ms,decode 吞吐提升约 46%。
和 ops-transformer 融合算子的关系
ops-nn 的融合粒度是"两个相邻算子"。ops-transformer 的融合粒度是"整个计算模式"(比如整个 Attention 过程)。
当 ops-transformer 的 FlashAttention 启用时,Attention 内部的 MatMul 已经被融合了,不需要 ops-nn 再介入。但 Attention 之外的 FFN 层,MatMul + SiLU + LayerNorm 的融合还是 ops-nn 在处理。
两层融合不冲突,各管各的地盘。
踩坑
LayerNorm 的 eps 参数必须显式传入。 PyTorch 的 nn.LayerNorm 默认 eps=1e-5,但融合接口如果你不传 eps,会走默认值 1e-8。精度差异在 float16 下可能被放大——训练场景影响不大,推理场景几乎无影响,但测试时精度对比可能不过。
权重排布要求连续。 融合接口要求 weight、bias、norm_weight、norm_bias 在 HBM 里连续存储。如果你的权重是通过 model.state_dict() 逐个加载的,内存不一定连续。解决方式是在加载后做一次 .contiguous() 调用。
FFN 层的 MatMul+Activation+LayerNorm 融合是"免费的午餐"——改动小、收益确定、跟 ops-transformer 不冲突。如果你的 decode 性能不够看,先检查这个融合有没有生效。仓库在这里:
https://atomgit.com/cann/ops-nn
更多推荐




所有评论(0)