第一次在昇腾NPU上跑 LLaMA,profiling 出来发现 LayerNorm 占了 8% 的推理时间——明明是个"小算子",怎么比 MatMul 还慢?

后来才发现:LLaMA 用的不是 LayerNorm,是 RMSNorm(Root Mean Square Layer Normalization),但标准实现没优化到位。


LayerNorm vs RMSNorm:差在哪?

LayerNorm(标准 Transformer 用):

LN(x) = γ * (x - μ) / √(σ² + ε) + β

要算:均值 μ、方差 σ²、两个可学习参数 γ 和 β。

RMSNorm(LLaMA / LLaMA-2 / Mistral 都用):

RMSNorm(x) = (x / RMS(x)) * γ
其中 RMS(x) = √(mean(x²) + ε)

只算:均方根 RMS(x)、一个可学习参数 γ。

少了什么: 不用算均值 μ,不用减均值,不用学 β 参数。

结果: RMSNorm 的计算量比 LayerNorm 少约 30%,但对精度的影
响极小(LLaMA 全程用 RMSNorm,没人觉得它精度不够)。


标准 RMSNorm 实现的问题

看起来简单(就一个均方根 + 除法),但标准实现有几个坑:

1. 需要两次 HBM 读写

标准实现:

# 第一次读:x 从 HBM 读进来
x = x.cuda()
# 算 RMS:需要把 x 的所有元素读一遍(第二次读 HBM)
rms = torch.sqrt(torch.mean(x ** 2) + eps)
# 除法:结果写回 HBM(第一次写)
out = x / rms * gamma

问题: x 被读了两遍(一次算 mean,一次做除法),HBM 带宽浪费。

2. 数值稳定性(FP16 下容易爆)

RMS 的计算是 sqrt(mean(x²)),如果 很大,FP16 会上溢(最大值 65504,平方之后直接 inf)。

标准实现里经常看到:

x = x.to(torch.float32)  # 先转 FP32
rms = torch.sqrt(torch.mean(x ** 2) + eps)
out = (x / rms * gamma).to(torch.float16)  # 再转回 FP16

问题: 频繁转 FP32 ↔ FP16,慢。


ops-transformer 里的 RMSNorm 算子优化

1. 合并成一次 HBM 读写(Fused Kernel)

ops-transformer 的实现里,RMSNorm 整个算子只有一个 Kernel,做完所有事情:

__aicore__ void FusedRMSNorm(
    AscendC::GlobalTensor<float> &x,    // 输入(HBM)
    AscendC::GlobalTensor<float> &out,   // 输出(HBM)
    AscendC::LocalTensor<float> &gamma,  // 可学习参数(存在 UB)
    int hiddenDim
) {
    // 第一步:把 x 的一个 Tile 读进 UB(只读一次)
    auto ubX = ctx.AllocTensor<float>(/*Tile 大小*/);
    AscendC::DataCopy(ubX, x[offset], tileSize);

    // 第二步:在 UB 里算 RMS(不读 HBM)
    auto ubX2 = ubX * ubX;                 // x²(UB 内)
    float meanX2 = AscendC::Mean(ubX2);    // mean(x²)(UB 内)
    float rms = sqrt(meanX2 + eps);         // RMS(UB 内)

    // 第三步:在 UB 里算除法 + 乘 gamma(不读 HBM)
    ubX = ubX / rms * gamma;               // 结果在 UB 里

    // 第四步:写回 HBM(只写一次)
    AscendC::DataCopy(out[offset], ubX, tileSize);
}

关键: x 只从 HBM 读一次,结果只写回 HBM 一次。中间计算全在 UB(片上内存)里完成。

2. FP16 的数值稳定性(不用转 FP32)

昇腾NPU 的 Vector 核支持 FP16 的 “加和到 FP32”(类似 TensorCore 的累积方式)。

ops-transformer 的 RMSNorm 里:

  • 用 FP16 算(快)
  • 累加 mean(x²) 的时候,累加器用 FP32(不溢出)
  • 最后 sqrt 和除法也在 FP32 累加器里做

结果: 不用显式转 FP32 ↔ FP16,数值稳定性够,速度快。

3. 多核并行(按 Batch 维度切分)

RMSNorm 的计算是 逐 token 独立的(每个 token 的 RMS 只和自己的 hidden 维度有关,和别的 token 无关)。

ops-transformer 里把 Batch × SeqLen 维度切分到多个 AI Core 上:

  • 每个 AI Core 负责 4-8 个 token 的 RMSNorm
  • AI Core 之间不需要通信(每个 token 独立)
  • 线性加速比(核心数翻倍,延迟减半)

实际收益(LLaMA-2 7B,Atlas 300I Duo,Batch=16)

配置 RMSNorm 延迟 (ms) 占总推理时间比例 数值稳定性(FP16)
标准 RMSNorm(PyTorch) ~1.8 8.2% 偶尔 NaN(大模型)
+ ops-transformer 优化 ~0.6 2.7% 无 NaN(FP32 累加)
提升幅度 -67% -5.5pp ✅ 稳定

代码示例(PyTorch,调用 RMSNorm)

import torch
import torch_npu

# LLaMA-2 7B 的 RMSNorm 配置
rms_norm_config = {
    "hidden_size": 4096,
    "eps": 1e-5,
    "elementwise_affine": True,  # 有 gamma 参数
}

# 在昇腾NPU上,RMSNorm 底层走的是 ops-transformer 的 Fused Kernel
# 不需要额外配置,CANN 8.0+ 自动识别
x = torch.randn(16, 4096).npu()  # [Batch, Hidden]
gamma = torch.randn(4096).npu()   # [Hidden]

# 调用 RMSNorm(底层是 ops-transformer 的 FusedRMSNorm Kernel)
out = F.rms_norm(x, (4096,), gamma, eps=1e-5)
# 上面的调用在昇腾NPU上走的是:
#   ops-transformer 的 FusedRMSNorm(一次 HBM 读写 + FP32 累加)

一个容易踩的坑

RMSNorm 的 eps 不能设太大。

比如设 eps = 1e-3(比默认的 1e-5 大 100 倍),会导致:

  • 梯度变小(RMS 的分母变大)
  • 训练不稳定(尤其是大模型,层数深,梯度累积误差大)

经验值:

  • eps = 1e-5(LLaMA 的默认配置)✅
  • eps = 1e-6(可以,但收益很小)✅
  • eps ≥ 1e-4(不推荐)❌

如果你想在自己的模型里用 RMSNorm,或者想把现有 LayerNorm 改成 RMSNorm,去 ops-transformer 的 ops/norm/ 目录:

https://atomgit.com/cann/ops-transformer

里面有:

  • rms_norm_kernel.cpp — RMSNorm 的 Ascend C 实现(Fused Kernel)
  • rms_norm_vs_layernorm.py — RMSNorm vs LayerNorm 的精度/速度对比
  • examples/rms_norm_profiling.py — 跑这个脚本看 RMSNorm 的 Timeline

一句话总结:RMSNorm 不是"新算法",是"让 LayerNorm 少算一点"——少算均值、少一个可学习参数,把两次 HBM 读写合成一次,速度和稳定性都上去了。

昇腾NPU 上跑 LLaMA,RMSNorm 的优化在 Batch 大的时候(≥16)收益更明显——因为 HBM 带宽成为瓶颈,Fused Kernel 减少 HBM 访问的优势就出来了。

Logo

作为“人工智能6S店”的官方数字引擎,为AI开发者与企业提供一个覆盖软硬件全栈、一站式门户。

更多推荐