【昇腾CANN训练营·进阶篇】精度侦探:使用PyTorch Hook与溢出检测工具定位数值异常
精度调试是对开发者耐心的终极考验。宏观:利用 Loss 曲线判断发散趋势。中观:利用 PyTorch Hook 快速定位出现 NaN 的层级。微观:利用 NPU 硬件检测和 Dump 工具,揪出算子内部的溢出。基准:始终以 FP32 的 CPU/GPU 结果为金标准(Golden Data)。掌握了这套侦探技能,你就不会再被莫名其妙的 Loss NaN 吓倒,而是能冷静地找出真凶。
训练营简介
2025年昇腾CANN训练营第二季,基于CANN开源开放全场景,推出0基础入门系列、码力全开特辑、开发者案例等专题课程,助力不同阶段开发者快速提升算子开发技能。获得Ascend C算子中级认证,即可领取精美证书,完成社区任务更有机会赢取华为手机,平板、开发板等大奖。
报名链接:https://www.hiascend.com/developer/activities/cann20252#cann-camp-2502-intro

前言
在 AI 开发中,有一种比“跑得慢”更可怕的情况,那就是**“算不对”**。
特别是当我们从 FP32 迁移到 FP16/BF16 混合精度训练时,由于动态范围缩窄,极易出现:
-
Overflow (上溢):数值超过 65504,变成
INF。 -
Underflow (下溢):数值太小被截断为 0,导致梯度消失。
-
NaN (Not a Number):
INF - INF或0 / 0,一旦出现,瞬间传染整个网络。
在几百层的 Transformer 中,找出一个异常数值如同大海捞针。本期文章将教你两套“侦探工具”:软件探针 (Hooks) 和 硬件陷阱 (Overflow Check)。
一、 核心图解:捕捉幽灵数值
精度异常通常是“幽灵”般的——它可能只在第 1000 次迭代的某一个算子中闪现一下,然后迅速破坏后续的所有计算。

二、 软件探针:PyTorch Hook 实战
PyTorch 提供了 register_forward_hook 和 register_backward_hook,允许我们在不修改模型源码的情况下,监控每一层 Layer 的输入输出。
2.1 编写“NaN 杀手” Hook
我们需要一个 Hook 函数,它能检查 Tensor 中是否包含 NaN 或 Inf。
import torch
def check_numerics_hook(module, inputs, outputs):
# 检查输出是否异常
# outputs 可能是 Tensor 或 Tuple
if isinstance(outputs, torch.Tensor):
tensors = [outputs]
else:
tensors = outputs
for i, t in enumerate(tensors):
if torch.isnan(t).any() or torch.isinf(t).any():
print(f" [Alert] Found NaN/Inf in module: {module.__class__.__name__}")
print(f" - Output index: {i}")
print(f" - Max: {t.max()}, Min: {t.min()}")
# 可以在这里 dump 数据以便后续分析
# torch.save(t, f"debug_{module.__class__.__name__}_out.pt")
# 激进策略:直接报错停止
raise ValueError("Numerical explosion detected!")
# 注册到模型的所有子模块
def register_hooks(model):
for name, layer in model.named_modules():
layer.register_forward_hook(check_numerics_hook)
使用方法:
model = MyLLM().npu()
register_hooks(model) # 注入探针
# 正常训练... 一旦出现 NaN,程序会立刻抛出异常并定位到具体 Layer
output = model(input)
三、 硬件陷阱:NPU 溢出检测
软件 Hook 虽然灵活,但有性能开销,且只能检测 Layer 边界。如果算子内部溢出(比如 Exp 的中间结果),Hook 是看不到的。
昇腾 AI Core 内置了 浮点状态寄存器,可以记录计算过程中是否发生了溢出。我们可以通过 ACL 配置来开启这个“硬件陷阱”。
3.1 开启溢出检测 (PyTorch)
在 torch_npu 中,我们可以配置 NPU 配置项 来开启检测。
import torch
import torch_npu
# 开启溢出检测模式
torch_npu.npu.set_compile_option(
jit_compile=False, # 建议关闭 JIT 以便更准确地定位
overflow_check=True
)
# 训练循环
try:
loss.backward()
# 在 Step 结束时检查是否溢出
if torch_npu.npu.get_npu_overflow_flag():
print("NPU Overflow detected in this step! Skipping update.")
optimizer.zero_grad() # 丢弃本次更新,防止权重被污染
# 可选:降低 Learning Rate 或调整 Loss Scale
else:
optimizer.step()
except RuntimeError as e:
print(f"Runtime Error: {e}")
3.2 进阶:定位到具体算子
如果开启全局检测后发现有溢出,如何知道是哪个算子干的? 昇腾提供了 Dump 功能。
-
配置
acl.json(或通过环境变量):{ "dump": { "dump_path": "./dump_data", "dump_mode": "all", // dump 所有算子 "dump_op_switch": "on" } } -
运行训练。
-
使用 msprof 或 MindStudio 解析 Dump 数据,工具会标记出 status 异常的算子。
四、 精度对齐:与 CPU/GPU 标杆对比
有时候结果没溢出,只是精度差(比如 NPU 算出来是 3.5,GPU 是 3.9)。这就需要做逐层比对。
核心思路:
-
固定随机种子(Seed),保证初始化参数一致。
-
保存同一份 Input 数据。
-
分别在 CPU/GPU 和 NPU 上运行模型,并 Hook 每一层的输出。
-
计算 Cosine Similarity 或 Max Diff。
工具推荐: 昇腾官方提供了 Pytorch Model Accuracy Analyzer 工具(通常集成在 MST (MindStudio Toolkit) 中),可以自动完成上述的比对流程,并生成 Excel 报告,标红误差超过阈值(如 1e-3)的层。
五、 总结
精度调试是对开发者耐心的终极考验。
-
宏观:利用 Loss 曲线判断发散趋势。
-
中观:利用 PyTorch Hook 快速定位出现 NaN 的层级。
-
微观:利用 NPU 硬件检测和 Dump 工具,揪出算子内部的溢出。
-
基准:始终以 FP32 的 CPU/GPU 结果为金标准(Golden Data)。
掌握了这套侦探技能,你就不会再被莫名其妙的 Loss NaN 吓倒,而是能冷静地找出真凶。
更多推荐





所有评论(0)