CANN-ops-math高精度Softmax-昇腾NPU上float32归一化为什么不能省
摘要:大模型推理中Softmax运算必须使用float32而非float16,主要原因是float16的指数运算容易溢出且精度不足。float16在减最大值时会产生较大误差,导致后续exp运算精度下降4%,而float32误差仅1e-6量级。昇腾NPU的ops-math实现通过两次高效Cast保持高精度计算,FlashAttention也保留这一步骤。bf16虽无溢出问题但精度不足,仍需float
CANN-ops-math高精度Softmax-昇腾NPU上float32归一化为什么不能省
大模型推理大多用 float16 跑,唯独 Softmax 这一步必须升到 float32。不是开发者保守,是 float16 的指数运算真的会溢出。ops-math 的高精度 Softmax 在昇腾NPU上做了什么,为什么融合算子里也得留这一步,这篇讲清楚。
float16 Softmax 的溢出问题
float16 的最大值是 65504。Softmax 的计算过程:
1. 减最大值:x_i - max(x) → 所有值 ≤ 0
2. 指数运算:exp(x_i - max(x)) → 所有值 ≤ 1
3. 归一化:exp_i / sum(exp)
第 1 步保证指数运算不溢出。但如果跳过第 1 步直接算 exp,float16 下只要 x > 11 就溢出(exp(11) ≈ 59874,接近 65504)。
问题出在减最大值本身也用 float16。假设序列长度 4096,logits 的方差约 8(初始化后的典型值),最大值可能到 30-40。float16 减法在大值附近的精度约 0.06,减完的结果在 [-40, 0] 区间,精度只有 0.06。这个误差在 exp 之后被放大——exp(-0.06) / exp(0) ≈ 0.94,4% 的相对误差。
float32 减法的精度约 1e-7,exp 后的误差在 1e-6 量级。差了四个数量级。
ops-math 的实现
ops-math 的 Softmax 标准流程:
1. 输入 float16 → Cast → float32
2. float32 减最大值
3. float32 指数运算
4. float32 归一化
5. 输出 float32 → Cast → float16
两次 Cast 看起来浪费,但昇腾NPU的 Vector 单元做 Cast 只需 1-2 个时钟周期,跟 exp 运算相比可以忽略。
FlashAttention 里的 Softmax
FlashAttention 把 Softmax 融合进了 Attention kernel,但 Softmax 的 float32 精度保持没有省:
FlashAttention kernel 内部:
Q·K^T → float16 结果
→ Cast float32
→ 减最大值(float32)
→ 指数运算(float32)
→ 归一化(float32)
→ Cast float16
→ 乘 V(float16)
这也是为什么 FlashAttention 的显存占用比标准 Attention 少——中间的 Softmax 结果不需要存到 HBM,在片上缓存完成 float32 计算后直接 Cast 回 float16 继续算。
bf16 能不能绕过这个问题
bf16 的指数范围跟 float32 一样大(±3.4×10^38),但精度只有 2 位十进制。bf16 的 Softmax 问题不是溢出,是精度:
bf16 减法精度:约 0.02(在 [-40, 0] 区间)
fp32 减法精度:约 1e-7
0.02 的误差在 exp 后约 2% 的相对误差。对于 Attention 权重来说,2% 的误差会直接影响 token 的关注度排序——Softmax 本来就是为了区分"该看哪里",2% 的噪声可能让模型关注到错误的 token。
所以 bf16 也得升 fp32 做 Softmax。结论:不管输入是 fp16 还是 bf16,Softmax 必须用 fp32 算。
在线 Softmax vs 两遍 Softmax
标准 Softmax 是两遍的:第一遍找最大值,第二遍算指数和归一化。FlashAttention 用的是在线 Softmax(Online Softmax),一遍完成:
在线 Softmax:
逐块处理,维护运行中的最大值 m 和累积和 l
每处理一个新块:
m_new = max(m_old, max(new_block))
l_new = l_old * exp(m_old - m_new) + sum(exp(new_block - m_new))
O_new = (O_old * l_old * exp(m_old - m_new) + new_block_result) / l_new
在线 Softmax 的核心优势是只需要一遍扫描,配合 FlashAttention 的分块计算正好。但数学上它引入了额外的乘法和除法来修正之前的累积结果——这些修正也必须在 float32 下完成。
性能影响
float32 Softmax 的额外开销:
| 步骤 | float16 版本 | float32 版本 |
|---|---|---|
| 减最大值 | 0.008ms | 0.008ms |
| 指数运算 | 0.012ms | 0.014ms |
| 归一化 | 0.006ms | 0.006ms |
| Cast × 2 | 0 | 0.002ms |
| 总计 | 0.026ms | 0.030ms |
float32 只慢 15%,但精度提升四个数量级。这笔交易怎么算都值。
Softmax 用 float32 不是保守,是数学上的必然。float16 的指数精度不够,bf16 也一样。ops-math 的高精度 Softmax 在融合算子内部自动处理,用户不需要手动 Cast。但如果你自己写 Softmax 相关的算子,记住这个规则。仓库在这里:
https://atomgit.com/cann/ops-math
更多推荐




所有评论(0)