FlashAttention输出全是NaN?数值问题排查指南
某团队在昇腾NPU上部署Llama-2-7B,用FlashAttention做推理。模型权重转换完成后,他们跑了一个简单测试:输入"Hello, world!",看模型能不能正常输出。结果输出的全是NaN(Not a Number)。他们排查了模型权重——权重没问题,都是正常的数值。他们又排查了输入——输入也没问题,tokenization正确。他们最后发现,问题出在FlashAttention的
FlashAttention输出全是NaN?数值问题排查指南
某团队在昇腾NPU上部署Llama-2-7B,用FlashAttention做推理。模型权重转换完成后,他们跑了一个简单测试:输入"Hello, world!",看模型能不能正常输出。结果输出的全是NaN(Not a Number)。
他们排查了模型权重——权重没问题,都是正常的数值。他们又排查了输入——输入也没问题,tokenization正确。他们最后发现,问题出在FlashAttention的数值范围上。
原来他们的模型是FP32训练后转FP16的,权重精度没问题,但KV Cache在推理时是FP16计算。FP16的动态范围只有65504,Softmax的指数运算在某些情况下会溢出,变成NaN。
这是一个典型的FlashAttention数值问题。今天把常见的三类数值问题——NaN/Inf、精度退化、注意力分布异常——逐个拆解,给出排查代码和修复方案。
先打个比方:计算器超载了
想象一个人用计算器算一个很大的数:e1000。普通计算器能显示的最大数是999999999,超过就报错了。高级计算器能显示科学计数法,比如1.0e1000。但如果你用的是那种"溢出就变乱码"的计算器,e1000的结果会变成一串随机数字——不是报错,而是乱算。
FP16就是那种"溢出就变乱码"的计算器。FlashAttention在昇腾NPU上默认用FP16做计算,如果数值超了范围,结果会变成NaN或Inf,而不是报错。
问题1:NaN和Inf从哪来?
FlashAttention产生NaN/Inf,通常有三个原因:
原因A:QK^T的值太大,Softmax溢出
Softmax(S) = exp(S - max(S)) / Σ exp(S - max(S))
问题:FP16的最大值是65504
如果 S - max(S) > ln(65504) ≈ 11.09,exp()就溢出了
当seq_len很大(比如16384)时,QK^T的打分范围会很大。如果不做数值稳定化处理,Softmax必然溢出。
# 标准Attention(不做数值稳定化)——会溢出
attn_scores = torch.matmul(q, k.transpose(-2, -1)) # shape: [B, H, S, S]
attn_probs = F.softmax(attn_scores, dim=-1)
# 当S=16384时,attn_scores的值可能超过11.09,softmax溢出
# FlashAttention(在线Softmax,自带数值稳定化)——不会溢出
# 但如果scale_value设错了,理论上也有风险
output = npu_flash_attention(
q, k, v,
head_num=32,
scale_value=1.0 / (head_dim ** 0.5) # scale_value必须正确
)
⚠️ 踩坑预警:scale_value如果设错了(比如用了1/sqrt(seq_len)而不是1/sqrt(head_dim)),会导致QK^T的值缩放不对。设大了会溢出,设小了会导致Softmax退化成one-hot分布,梯度消失。
原因B:KV Cache的数值漂移
推理时,KV Cache在FP16下不断累积。如果每次更新KV Cache时有一点误差,累积几千步之后,误差会变得很大,最终变成NaN。
# 问题代码:FP16累加误差
k_cache = torch.zeros(B, H, max_seq_len, head_dim, dtype=torch.float16, device='npu')
for i in range(num_steps):
k_new = compute_new_k() # 每次有一点浮点误差
k_cache[:, :, i:i+1, :] = k_new # 误差累积
# 1000步之后,k_cache里可能已经有显著的数值漂移了
原因C:位置编码注入的值太大
RoPE(旋转位置编码)在长序列时会让Q和K的值振荡。如果序列太长,振荡的幅度会叠加,导致某些位置的QK^T值特别大。
问题2:怎么排查NaN和Inf?
方法1:逐层打印激活值
import torch
import torch.nn.functional as F
def check_nan_inf(tensor, name="tensor"):
"""检查tensor里有没有NaN或Inf"""
has_nan = torch.isnan(tensor).any().item()
has_inf = torch.isinf(tensor).any().item()
if has_nan or has_inf:
print(f"❌ {name} 包含NaN={has_nan}, Inf={has_inf}")
print(f" shape={tensor.shape}, dtype={tensor.dtype}")
print(f" min={tensor.min().item():.4f}, max={tensor.max().item():.4f}")
print(f" mean={tensor.mean().item():.4f}, std={tensor.std().item():.4f}")
return False
else:
print(f"✅ {name} 无NaN/Inf")
return True
def check_flash_attention_nan(q, k, v, head_num):
"""检查FlashAttention各环节是否有NaN"""
print("\n=== FlashAttention数值检查 ===")
# Step 1: 检查QKV输入
check_nan_inf(q, "Q")
check_nan_inf(k, "K")
check_nan_inf(v, "V")
# Step 2: 检查QK^T
qk = torch.matmul(q, k.transpose(-2, -1))
check_nan_inf(qk, "QK^T")
print(f" QK^T 值范围:[{qk.min().item():.4f}, {qk.max().item():.4f}]")
# Step 3: 检查Softmax
scale = 1.0 / (q.shape[-1] ** 0.5)
scaled_qk = qk * scale
attn = F.softmax(scaled_qk, dim=-1)
check_nan_inf(attn, "Softmax输出")
# Step 4: 检查FlashAttention输出
flash_out = npu_flash_attention(q, k, v, head_num=head_num, scale_value=scale)
check_nan_inf(flash_out, "FlashAttention输出")
return check_nan_inf(flash_out, "最终结果")
# 测试
q = torch.randn(1, 32, 4096, 128, device='npu', dtype=torch.float16)
k = torch.randn(1, 32, 4096, 128, device='npu', dtype=torch.float16)
v = torch.randn(1, 32, 4096, 128, device='npu', dtype=torch.float16)
check_flash_attention_nan(q, k, v, head_num=32)
方法2:检查注意力分布
正常情况下,Attention的输出是一个概率分布(每行和为1)。如果Attention退化(所有值集中在某一个位置),说明有数值问题。
def analyze_attention_distribution(attn_weights, name="Attention"):
"""分析注意力分布是否正常"""
# attn_weights: [B, H, S, S]
# 每行的和(应该是1)
row_sums = attn_weights.sum(dim=-1) # [B, H, S]
print(f"{name} 行和:min={row_sums.min().item():.6f}, max={row_sums.max().item():.6f}")
# 每行的熵(正常应该>0,退化会趋近0)
entropy = -(attn_weights * torch.log(attn_weights + 1e-10)).sum(dim=-1)
print(f"{name} 熵:mean={entropy.mean().item():.4f}, min={entropy.min().item():.4f}")
# 每行的最大值(正常应该<0.5,退化会接近1)
max_probs = attn_weights.amax(dim=-1)
print(f"{name} 最大概率:mean={max_probs.mean().item():.4f}, max={max_probs.max().item():.4f}")
# 判断是否退化
if entropy.mean().item() < 0.1:
print("⚠️ Attention熵极低,可能退化成one-hot分布!")
return False
if max_probs.mean().item() > 0.95:
print("⚠️ Attention最大概率极高,可能退化成one-hot分布!")
return False
return True
# 用标准Attention做参考
attn_ref = F.softmax(torch.randn(1, 32, 512, 512, device='npu', dtype=torch.float16), dim=-1)
analyze_attention_distribution(attn_ref, "标准Attention(参考)")
方法3:对比标准Attention和FlashAttention
def compare_attention_output(q, k, v, head_num, rtol=1e-3, atol=1e-3):
"""对比标准Attention和FlashAttention的输出"""
# 标准Attention(ground truth)
with torch.no_grad():
std_q = q.float() # FP32,避免精度问题
std_k = k.float()
std_v = v.float()
scale = 1.0 / (q.shape[-1] ** 0.5)
std_scores = torch.matmul(std_q, std_k.transpose(-2, -1)) * scale
std_attn = F.softmax(std_scores, dim=-1)
std_out = torch.matmul(std_attn, std_v)
# FlashAttention
flash_out = npu_flash_attention(q, k, v, head_num=head_num).float()
# 对比
diff = (std_out - flash_out).abs()
max_diff = diff.max().item()
mean_diff = diff.mean().item()
print(f"\n=== 标准Attention vs FlashAttention ===")
print(f"最大误差:{max_diff:.6f}")
print(f"平均误差:{mean_diff:.6f}")
print(f"容限:rtol={rtol}, atol={atol}")
within_tolerance = torch.allclose(std_out, flash_out, rtol=rtol, atol=atol)
if within_tolerance:
print("✅ 误差在容限范围内,FlashAttention正常工作")
else:
print("❌ 误差超出容限,FlashAttention有问题")
return within_tolerance
问题3:怎么修复NaN和Inf?
修复1:改用BF16
BF16的动态范围比FP16大得多(10^38级别,不会溢出),如果昇腾NPU支持BF16,切换到BF16可以解决大部分数值问题。
# 用BF16跑FlashAttention
model = model.bfloat16() # 注意:不是.half()
q = q.bfloat16()
k = k.bfloat16()
v = v.bfloat16()
output = npu_flash_attention(
q, k, v,
head_num=32,
scale_value=1.0 / (head_dim ** 0.5),
# 注意:需要确认算子支持BF16
)
⚠️ 踩坑预警:BF16的精度比FP16低(7位有效数字 vs 10位)。对于需要高精度的场景(比如训练),BF16可能不够。对于推理,BF16通常足够。
修复2:启用Softmax的在线稳定化
ops-transformer的FlashAttention默认启用了在线稳定化(就是减max的技巧)。如果用的是自定义实现,确保加上这个技巧:
def stable_flash_attention(q, k, v, scale):
"""数值稳定的FlashAttention"""
# 减max,避免溢出
m = q.shape[2] # seq_len
n = k.shape[2] # seq_len
# 在线Softmax
m_i = torch.full((q.shape[0], q.shape[1], q.shape[2], 1),
float('-inf'), device=q.device, dtype=q.dtype)
l_i = torch.zeros(q.shape[0], q.shape[1], q.shape[2], 1, device=q.device, dtype=q.dtype)
o = torch.zeros_like(q)
# 分块处理(这里省略具体实现)
# ...
return o
修复3:减小RoPE的频率范围(针对长序列问题)
RoPE在长序列时的振荡问题,可以减小base频率来缓解:
# 原始RoPE base=10000
# 长序列下改为更小的base
class ScaledRoPE(torch.nn.Module):
def __init__(self, dim, max_seq_len=4096, base=10000):
super().__init__()
# 减小base,降低高频振荡
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer("inv_freq", inv_freq)
def forward(self, x, seq_len):
# 生成位置编码
t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
return emb
总结:数值问题排查清单
FlashAttention输出NaN,按这个清单查:
| 检查步骤 | 操作 | 判断标准 |
|---|---|---|
| 1. 查QKV输入 | 逐层打印tensor | 有NaN→模型权重问题 |
| 2. 查scale_value | 确认是1/sqrt(head_dim) | 设错了→数值溢出 |
| 3. 查dtype | 确认是FP16还是BF16 | FP16且seq_len大→试试BF16 |
| 4. 查注意力分布 | 算熵和最大概率 | 熵<0.1或最大概率>0.95→退化 |
| 5. 对比标准Attention | 算误差 | 误差>1e-2→实现有问题 |
| 6. 查RoPE频率 | 看位置编码的值范围 | 太大→减小base |
代码和文档:
https://atomgit.com/cann/ops-transformer
更多推荐




所有评论(0)