CANN-ops-math类型转换算子-昇腾NPU上fp16和bf16怎么互转才不拖后腿
混合精度训练在每层要做 4-6 次类型转换:fp16 计算 → fp32 归一化 → fp16 存储。推理也有 bf16 权重转 fp16 计算的场景。类型转换(Cast)的频率比大部分人的直觉高得多,ops-math 的 Cast 实现如果不够快,混合精度的收益会被转换开销吃掉。
Cast 的两种实现路径
路径 1:DMA 引擎转换。 昇腾NPU的 DMA 引擎在搬运数据时可以直接做格式转换。数据从源地址搬到目标地址的过程中,硬件自动完成 fp16→fp32 或 fp32→fp16 的转换。不占 Vector 单元,不占 Cube 单元。
路径 2:Vector 单元转换。 当数据不是连续存储时(比如经过 transpose 或 stride 操作),DMA 无法直接转换,需要 Vector 单元逐元素处理。
连续 tensor:
DMA 搬运 + 转换,带宽利用率 95%+
非连续 tensor:
DMA 搬运 → Vector 转换 → DMA 写回,带宽利用率 30-40%
3 倍的带宽差距。 这就是为什么 ops-math 的文档反复强调"确保 tensor 是 contiguous 的"。
fp16 vs bf16:选哪个
昇腾NPU对 fp16 和 bf16 都有原生支持。两者在 Cast 时的开销一样,选择取决于你的模型:
| 精度 | 范围 | 精度 | 适用场景 |
|---|---|---|---|
| fp16 | ±65504 | 约 3 位十进制 | 推理为主,训练需 loss scaling |
| bf16 | ±3.4×10^38 | 约 2 位十进制 | 训练为主,不需要 loss scaling |
| fp32 | ±3.4×10^38 | 约 7 位十进制 | 归一化、loss 计算 |
bf16 的范围跟 fp32 一样大,训练时不容易溢出,不需要 loss scaling。但精度比 fp16 低 1 位——推理时 bf16 的输出质量可能略差于 fp16。
在昇腾NPU上,Llama 系列模型推荐 fp16 推理(社区权重大多是 fp16),训练推荐 bf16。
Cast 在 Attention 里的位置
FlashAttention 内部有两次 Cast:
1. Q·K^T 的结果(fp16)→ Cast → fp32 → Softmax → Cast → fp16
Softmax 的指数运算需要 fp32 精度,fp16 会溢出
2. Attention 输出(fp16)→ 直接传给下一个算子,不需要 Cast
这两次 Cast 在 FlashAttention 融合 kernel 内部完成,不走独立的 Cast 算子。数据在 Vector 单元的 local buffer 里完成 fp16↔fp32 转换,不需要 HBM 读写。
如果你用的是标准 Attention(非融合),每次 Softmax 前后的 Cast 是独立 kernel,每次约 0.02ms。32 层 × 2 次 = 0.64ms 的额外延迟。融合后这部分为零。
混合精度训练的 Cast 热点
训练时每层的 Cast 次数:
1. MatMul 输入 fp16→fp32(CUBLAS/torch_npu 自动处理,对用户透明)
2. MatMul 输出 fp32→fp16(同上)
3. LayerNorm/RMSNorm:输入 fp16→fp32,输出 fp32→fp16
4. Loss 计算:fp16→fp32
5. 梯度更新:fp32→fp16
步骤 3 的 Cast 最频繁(每层 2 次,32 层 64 次),但也最容易被融合消除。ops-nn 的 fused_linear_act_norm 接口把 Linear + SiLU + LayerNorm 融合后,LayerNorm 前后的 Cast 都在 kernel 内部完成。
性能数据
单独测 Cast 的延迟(Atlas 800I A2,[1, 4096, 4096] tensor):
| 类型 | 连续 tensor | 非连续 tensor |
|---|---|---|
| fp16→fp32 | 0.018ms | 0.052ms |
| fp32→fp16 | 0.018ms | 0.052ms |
| fp16→bf16 | 0.020ms | 0.058ms |
| bf16→fp16 | 0.020ms | 0.058ms |
非连续 tensor 的 Cast 慢 3 倍。如果你在模型里做了 reshape、permute、stride 等操作,记得加 .contiguous() 再做类型转换。
Cast 不是性能主角,但它无处不在。混合精度训练的 Cast 开销通常占总训练时间的 2-3%,看起来不多,但全部消除后相当于免费多了一张卡 3% 的算力。方法是:走融合算子路径 + 确保 tensor contiguous。仓库在这里:
https://atomgit.com/cann/ops-math
更多推荐


所有评论(0)