FlashAttention输出全是NaN?数值问题排查指南

某团队在昇腾NPU上部署Llama-2-7B,用FlashAttention做推理。模型权重转换完成后,他们跑了一个简单测试:输入"Hello, world!",看模型能不能正常输出。结果输出的全是NaN(Not a Number)。

他们排查了模型权重——权重没问题,都是正常的数值。他们又排查了输入——输入也没问题,tokenization正确。他们最后发现,问题出在FlashAttention的数值范围上。

原来他们的模型是FP32训练后转FP16的,权重精度没问题,但KV Cache在推理时是FP16计算。FP16的动态范围只有65504,Softmax的指数运算在某些情况下会溢出,变成NaN。

这是一个典型的FlashAttention数值问题。今天把常见的三类数值问题——NaN/Inf、精度退化、注意力分布异常——逐个拆解,给出排查代码和修复方案。

先打个比方:计算器超载了

想象一个人用计算器算一个很大的数:e1000。普通计算器能显示的最大数是999999999,超过就报错了。高级计算器能显示科学计数法,比如1.0e1000。但如果你用的是那种"溢出就变乱码"的计算器,e1000的结果会变成一串随机数字——不是报错,而是乱算。

FP16就是那种"溢出就变乱码"的计算器。FlashAttention在昇腾NPU上默认用FP16做计算,如果数值超了范围,结果会变成NaN或Inf,而不是报错。

问题1:NaN和Inf从哪来?

FlashAttention产生NaN/Inf,通常有三个原因:

原因A:QK^T的值太大,Softmax溢出

Softmax(S) = exp(S - max(S)) / Σ exp(S - max(S))

问题:FP16的最大值是65504
如果 S - max(S) > ln(65504) ≈ 11.09,exp()就溢出了

当seq_len很大(比如16384)时,QK^T的打分范围会很大。如果不做数值稳定化处理,Softmax必然溢出。

# 标准Attention(不做数值稳定化)——会溢出
attn_scores = torch.matmul(q, k.transpose(-2, -1))  # shape: [B, H, S, S]
attn_probs = F.softmax(attn_scores, dim=-1)
# 当S=16384时,attn_scores的值可能超过11.09,softmax溢出

# FlashAttention(在线Softmax,自带数值稳定化)——不会溢出
# 但如果scale_value设错了,理论上也有风险
output = npu_flash_attention(
    q, k, v,
    head_num=32,
    scale_value=1.0 / (head_dim ** 0.5)  # scale_value必须正确
)

⚠️ 踩坑预警:scale_value如果设错了(比如用了1/sqrt(seq_len)而不是1/sqrt(head_dim)),会导致QK^T的值缩放不对。设大了会溢出,设小了会导致Softmax退化成one-hot分布,梯度消失。

原因B:KV Cache的数值漂移

推理时,KV Cache在FP16下不断累积。如果每次更新KV Cache时有一点误差,累积几千步之后,误差会变得很大,最终变成NaN。

# 问题代码:FP16累加误差
k_cache = torch.zeros(B, H, max_seq_len, head_dim, dtype=torch.float16, device='npu')

for i in range(num_steps):
    k_new = compute_new_k()  # 每次有一点浮点误差
    k_cache[:, :, i:i+1, :] = k_new  # 误差累积

# 1000步之后,k_cache里可能已经有显著的数值漂移了

原因C:位置编码注入的值太大

RoPE(旋转位置编码)在长序列时会让Q和K的值振荡。如果序列太长,振荡的幅度会叠加,导致某些位置的QK^T值特别大。

问题2:怎么排查NaN和Inf?

方法1:逐层打印激活值

import torch
import torch.nn.functional as F

def check_nan_inf(tensor, name="tensor"):
    """检查tensor里有没有NaN或Inf"""
    has_nan = torch.isnan(tensor).any().item()
    has_inf = torch.isinf(tensor).any().item()
    
    if has_nan or has_inf:
        print(f"❌ {name} 包含NaN={has_nan}, Inf={has_inf}")
        print(f"   shape={tensor.shape}, dtype={tensor.dtype}")
        print(f"   min={tensor.min().item():.4f}, max={tensor.max().item():.4f}")
        print(f"   mean={tensor.mean().item():.4f}, std={tensor.std().item():.4f}")
        return False
    else:
        print(f"✅ {name} 无NaN/Inf")
        return True

def check_flash_attention_nan(q, k, v, head_num):
    """检查FlashAttention各环节是否有NaN"""
    print("\n=== FlashAttention数值检查 ===")
    
    # Step 1: 检查QKV输入
    check_nan_inf(q, "Q")
    check_nan_inf(k, "K")
    check_nan_inf(v, "V")
    
    # Step 2: 检查QK^T
    qk = torch.matmul(q, k.transpose(-2, -1))
    check_nan_inf(qk, "QK^T")
    print(f"   QK^T 值范围:[{qk.min().item():.4f}, {qk.max().item():.4f}]")
    
    # Step 3: 检查Softmax
    scale = 1.0 / (q.shape[-1] ** 0.5)
    scaled_qk = qk * scale
    attn = F.softmax(scaled_qk, dim=-1)
    check_nan_inf(attn, "Softmax输出")
    
    # Step 4: 检查FlashAttention输出
    flash_out = npu_flash_attention(q, k, v, head_num=head_num, scale_value=scale)
    check_nan_inf(flash_out, "FlashAttention输出")
    
    return check_nan_inf(flash_out, "最终结果")

# 测试
q = torch.randn(1, 32, 4096, 128, device='npu', dtype=torch.float16)
k = torch.randn(1, 32, 4096, 128, device='npu', dtype=torch.float16)
v = torch.randn(1, 32, 4096, 128, device='npu', dtype=torch.float16)

check_flash_attention_nan(q, k, v, head_num=32)

方法2:检查注意力分布

正常情况下,Attention的输出是一个概率分布(每行和为1)。如果Attention退化(所有值集中在某一个位置),说明有数值问题。

def analyze_attention_distribution(attn_weights, name="Attention"):
    """分析注意力分布是否正常"""
    # attn_weights: [B, H, S, S]
    
    # 每行的和(应该是1)
    row_sums = attn_weights.sum(dim=-1)  # [B, H, S]
    print(f"{name} 行和:min={row_sums.min().item():.6f}, max={row_sums.max().item():.6f}")
    
    # 每行的熵(正常应该>0,退化会趋近0)
    entropy = -(attn_weights * torch.log(attn_weights + 1e-10)).sum(dim=-1)
    print(f"{name} 熵:mean={entropy.mean().item():.4f}, min={entropy.min().item():.4f}")
    
    # 每行的最大值(正常应该<0.5,退化会接近1)
    max_probs = attn_weights.amax(dim=-1)
    print(f"{name} 最大概率:mean={max_probs.mean().item():.4f}, max={max_probs.max().item():.4f}")
    
    # 判断是否退化
    if entropy.mean().item() < 0.1:
        print("⚠️ Attention熵极低,可能退化成one-hot分布!")
        return False
    
    if max_probs.mean().item() > 0.95:
        print("⚠️ Attention最大概率极高,可能退化成one-hot分布!")
        return False
    
    return True

# 用标准Attention做参考
attn_ref = F.softmax(torch.randn(1, 32, 512, 512, device='npu', dtype=torch.float16), dim=-1)
analyze_attention_distribution(attn_ref, "标准Attention(参考)")

方法3:对比标准Attention和FlashAttention

def compare_attention_output(q, k, v, head_num, rtol=1e-3, atol=1e-3):
    """对比标准Attention和FlashAttention的输出"""
    
    # 标准Attention(ground truth)
    with torch.no_grad():
        std_q = q.float()  # FP32,避免精度问题
        std_k = k.float()
        std_v = v.float()
        
        scale = 1.0 / (q.shape[-1] ** 0.5)
        std_scores = torch.matmul(std_q, std_k.transpose(-2, -1)) * scale
        std_attn = F.softmax(std_scores, dim=-1)
        std_out = torch.matmul(std_attn, std_v)
    
    # FlashAttention
    flash_out = npu_flash_attention(q, k, v, head_num=head_num).float()
    
    # 对比
    diff = (std_out - flash_out).abs()
    max_diff = diff.max().item()
    mean_diff = diff.mean().item()
    
    print(f"\n=== 标准Attention vs FlashAttention ===")
    print(f"最大误差:{max_diff:.6f}")
    print(f"平均误差:{mean_diff:.6f}")
    print(f"容限:rtol={rtol}, atol={atol}")
    
    within_tolerance = torch.allclose(std_out, flash_out, rtol=rtol, atol=atol)
    if within_tolerance:
        print("✅ 误差在容限范围内,FlashAttention正常工作")
    else:
        print("❌ 误差超出容限,FlashAttention有问题")
    
    return within_tolerance

问题3:怎么修复NaN和Inf?

修复1:改用BF16

BF16的动态范围比FP16大得多(10^38级别,不会溢出),如果昇腾NPU支持BF16,切换到BF16可以解决大部分数值问题。

# 用BF16跑FlashAttention
model = model.bfloat16()  # 注意:不是.half()

q = q.bfloat16()
k = k.bfloat16()
v = v.bfloat16()

output = npu_flash_attention(
    q, k, v,
    head_num=32,
    scale_value=1.0 / (head_dim ** 0.5),
    # 注意:需要确认算子支持BF16
)

⚠️ 踩坑预警:BF16的精度比FP16低(7位有效数字 vs 10位)。对于需要高精度的场景(比如训练),BF16可能不够。对于推理,BF16通常足够。

修复2:启用Softmax的在线稳定化

ops-transformer的FlashAttention默认启用了在线稳定化(就是减max的技巧)。如果用的是自定义实现,确保加上这个技巧:

def stable_flash_attention(q, k, v, scale):
    """数值稳定的FlashAttention"""
    # 减max,避免溢出
    m = q.shape[2]  # seq_len
    n = k.shape[2]  # seq_len
    
    # 在线Softmax
    m_i = torch.full((q.shape[0], q.shape[1], q.shape[2], 1), 
                      float('-inf'), device=q.device, dtype=q.dtype)
    l_i = torch.zeros(q.shape[0], q.shape[1], q.shape[2], 1, device=q.device, dtype=q.dtype)
    o = torch.zeros_like(q)
    
    # 分块处理(这里省略具体实现)
    # ...
    
    return o

修复3:减小RoPE的频率范围(针对长序列问题)

RoPE在长序列时的振荡问题,可以减小base频率来缓解:

# 原始RoPE base=10000
# 长序列下改为更小的base
class ScaledRoPE(torch.nn.Module):
    def __init__(self, dim, max_seq_len=4096, base=10000):
        super().__init__()
        # 减小base,降低高频振荡
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
    
    def forward(self, x, seq_len):
        # 生成位置编码
        t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb

总结:数值问题排查清单

FlashAttention输出NaN,按这个清单查:

检查步骤 操作 判断标准
1. 查QKV输入 逐层打印tensor 有NaN→模型权重问题
2. 查scale_value 确认是1/sqrt(head_dim) 设错了→数值溢出
3. 查dtype 确认是FP16还是BF16 FP16且seq_len大→试试BF16
4. 查注意力分布 算熵和最大概率 熵<0.1或最大概率>0.95→退化
5. 对比标准Attention 算误差 误差>1e-2→实现有问题
6. 查RoPE频率 看位置编码的值范围 太大→减小base

代码和文档:

https://atomgit.com/cann/ops-transformer

Logo

作为“人工智能6S店”的官方数字引擎,为AI开发者与企业提供一个覆盖软硬件全栈、一站式门户。

更多推荐