CANN-ops-nn融合MatMul加LayerNorm-昇腾NPU上两个最忙算子怎么省一遍读写

大模型的每一层 Transformer 里,MatMul 和 LayerNorm 是出勤率最高的两个算子。标准实现下它们各跑各的——MatMul 在 Cube 单元算完写回 HBM,LayerNorm 从 HBM 读出来在 Vector 单元算完再写回去。ops-nn 的融合版本让这两个算子共享片上缓存,省掉中间那一轮 HBM 读写。

为什么是这两个

单层 Transformer 里 MatMul 出现的次数:

Q Linear, K Linear, V Linear     → 3 次 MatMul
Attention Output Linear          → 1 次 MatMul
Gate Linear, Up Linear, Down Linear → 3 次 MatMul

LayerNorm(或 RMSNorm)出现 2 次:Attention 前一次,FFN 前一次。

而且它们的搭配是固定的:每组 MatMul 的输出 → 激活函数 → 下一个操作 → LayerNorm。数据流是线性的,融合条件天然满足。

融合前后的数据流

标准实现:
  HBM → Cube(MatMul) → HBM → Vector(Activation) → HBM → Vector(LayerNorm) → HBM
  4 次 HBM 读写

融合实现:
  HBM → Cube(MatMul) → 片上缓存 → Vector(Activation) → 片上缓存 → Vector(LayerNorm) → HBM
  2 次 HBM 读写(只读输入、只写最终输出)

HBM 读写减半。在昇腾NPU上 HBM 带宽约 1.2 TB/s,省一次 4096×4096 的 float16 读写(32MB)就是省约 27μs 的延迟。单层数字不大,32 层叠起来就是 0.86ms。

0.86ms 听起来也不多?在 decode 阶段,每生成一个 token 只需要 3-5ms,0.86ms 就是 17-28% 的延迟优化。

ops-nn 的融合接口

import torch_npu

# MatMul + Bias + SiLU + LayerNorm 四合一
out = torch_npu.npu.fused_linear_act_norm(
    x,           # 输入 [batch, seq, hidden]
    weight,      # 权重 [hidden, ff_dim]
    bias,        # 偏置 [ff_dim]
    norm_weight, # LayerNorm 权重 [ff_dim]
    norm_bias,   # LayerNorm 偏置 [ff_dim]
    activation="silu",
    eps=1e-5
)

这个接口把四个操作压进一个 kernel。中间结果全部在 AI Core 的片上缓存里流转,不出 HBM。

不过有个限制:融合接口要求 MatMul 的输出维度和 LayerNorm 的归一化维度一致。大部分标准 Transformer 模型满足这个条件。如果你做了自定义的维度变换(比如输出后接一个 reshape 再做 LayerNorm),融合会失效。

性能对比

Atlas 800I A2,Llama2-7B decode 阶段单层:

配置 单层延迟 (ms) HBM 读写 (GB)
标准分离 3.8 1.2
MatMul+Activation 融合 3.2 0.9
MatMul+Act+LayerNorm 融合 2.6 0.6

32 层总计延迟从 121ms 降到 83ms,decode 吞吐提升约 46%。

和 ops-transformer 融合算子的关系

ops-nn 的融合粒度是"两个相邻算子"。ops-transformer 的融合粒度是"整个计算模式"(比如整个 Attention 过程)。

当 ops-transformer 的 FlashAttention 启用时,Attention 内部的 MatMul 已经被融合了,不需要 ops-nn 再介入。但 Attention 之外的 FFN 层,MatMul + SiLU + LayerNorm 的融合还是 ops-nn 在处理。

两层融合不冲突,各管各的地盘。

踩坑

LayerNorm 的 eps 参数必须显式传入。 PyTorch 的 nn.LayerNorm 默认 eps=1e-5,但融合接口如果你不传 eps,会走默认值 1e-8。精度差异在 float16 下可能被放大——训练场景影响不大,推理场景几乎无影响,但测试时精度对比可能不过。

权重排布要求连续。 融合接口要求 weight、bias、norm_weight、norm_bias 在 HBM 里连续存储。如果你的权重是通过 model.state_dict() 逐个加载的,内存不一定连续。解决方式是在加载后做一次 .contiguous() 调用。


FFN 层的 MatMul+Activation+LayerNorm 融合是"免费的午餐"——改动小、收益确定、跟 ops-transformer 不冲突。如果你的 decode 性能不够看,先检查这个融合有没有生效。仓库在这里:

https://atomgit.com/cann/ops-nn

Logo

作为“人工智能6S店”的官方数字引擎,为AI开发者与企业提供一个覆盖软硬件全栈、一站式门户。

更多推荐