某团队在昇腾NPU上实现了FlashAttention,测试时发现大多数case都正确,但有几个边界case输出了错误结果——Loss只差一点点,但长序列场景下输出差异明显。他们怀疑是数值精度问题,也怀疑是算法实现有bug。

问题是:没有系统性地验证FlashAttention的正确性。测试几个case不能证明算法是对的,需要形式化验证和边界条件覆盖。

今天把FlashAttention的正确性验证方法讲清楚,给出完整的测试框架和形式化验证思路。

为什么FlashAttention的正确性需要专门验证

标准Attention vs FlashAttention的数学等价性

标准Attention:
  S[i,j] = Q[i] · K[j] / √d
  P[i,j] = exp(S[i,j]) / Σexp(S[i,*])
  O[i] = Σj P[i,j] · V[j]

FlashAttention(分块计算):
  把序列分成T个blocks
  每个block Ti计算:
    mi = rowmax(Si,*), li = rowsum(exp(Si,* - mi))
  最终合并:
    O = Σi exp(mi - m) / L · Oi
    L = Σi exp(mi - m) / li

数学上要求:
  Σi exp(mi - m) / L · Oi = Σj P[i,j] · V[j]
  
这个等式的成立需要:
  1. 每个block的m和l计算正确
  2. 合并公式的数学推导正确
  3. 数值稳定性处理(exp溢出)正确
  4. 所有block的合并顺序正确

数值等价性验证

分层测试框架

import torch
import numpy as np
from typing import List, Tuple

class FlashAttentionVerifier:
    """
    FlashAttention正确性验证器
    
    验证策略:
      1. 数值等价性:跟标准Attention对比输出
      2. 数学性质:验证softmax的数学性质(和为1、非负等)
      3. 边界条件:空序列、单token、极端值
      4. 回归测试:确保改动不引入新的bug
    """
    
    def __init__(self, rtol=1e-3, atol=1e-5):
        self.rtol = rtol  # 相对误差容限
        self.atol = atol  # 绝对误差容限
    
    def standard_attention(self, q, k, v):
        """
        标准Attention实现(参考实现)
        """
        d = q.shape[-1]
        scale = 1.0 / (d ** 0.5)
        
        # S = Q @ K^T
        scores = torch.matmul(q, k.transpose(-2, -1)) * scale
        
        # P = softmax(S)
        # 数值稳定的softmax
        scores_max = scores.amax(dim=-1, keepdim=True)
        scores_exp = torch.exp(scores - scores_max)
        P = scores_exp / scores_exp.sum(dim=-1, keepdim=True)
        
        # O = P @ V
        O = torch.matmul(P, v)
        
        return O, P
    
    def flash_attention(self, q, k, v, block_size=128):
        """
        FlashAttention实现(待验证)
        """
        
        B, H, S, D = q.shape
        
        # 初始化
        O = torch.zeros_like(q)
        m = torch.full((B, H, S, 1), -float('inf'), device=q.device)
        l = torch.zeros((B, H, S, 1), device=q.device)
        
        # 分block计算
        num_blocks = (S + block_size - 1) // block_size
        
        for i in range(num_blocks):
            start_i = i * block_size
            end_i = min(start_i + block_size, S)
            
            # 读取block_i的K、V
            k_i = k[:, :, start_i:end_i, :]
            v_i = v[:, :, start_i:end_i, :]
            
            # 计算block_i的m和l
            scale = 1.0 / (D ** 0.5)
            scores_i = torch.matmul(q, k_i.transpose(-2, -1)) * scale
            
            m_i = scores_i.amax(dim=-1, keepdim=True)
            p_i = torch.exp(scores_i - m_i)
            l_i = p_i.sum(dim=-1, keepdim=True)
            
            # 合并
            m_new = torch.maximum(m[:, :, start_i:end_i], m_i)
            
            # 防止m全为-inf
            m_safe = torch.where(torch.isinf(m_new), torch.zeros_like(m_new), m_new)
            
            p_i_scaled = torch.exp(m_i - m_safe)
            l_i_scaled = p_i_scaled * l_i
            
            # 旧的贡献
            old_scale = torch.exp(m[:, :, start_i:end_i] - m_safe)
            old_scale = torch.where(torch.isinf(old_scale), torch.zeros_like(old_scale), old_scale)
            O_old = O[:, :, start_i:end_i] * old_scale
            l_old = l[:, :, start_i:end_i] * old_scale
            
            # 新的贡献
            O_new = torch.matmul(p_i_scaled / (l_i_scaled + 1e-10), v_i)
            
            # 更新
            l_combined = l_i_scaled + l_old
            O[:, :, start_i:end_i] = (O_old + O_new) / (l_combined + 1e-10)
            m[:, :, start_i:end_i] = m_safe
            l[:, :, start_i:end_i] = l_combined
        
        return O, None
    
    def verify_numerical_equivalence(self, q, k, v):
        """
        验证:FlashAttention输出跟标准Attention等价
        """
        
        O_ref, P_ref = self.standard_attention(q, k, v)
        O_fa, _ = self.flash_attention(q, k, v)
        
        # 计算误差
        abs_diff = (O_ref - O_fa).abs()
        rel_diff = abs_diff / (O_ref.abs() + 1e-8)
        
        max_abs_diff = abs_diff.max().item()
        max_rel_diff = rel_diff.max().item()
        mean_rel_diff = rel_diff.mean().item()
        
        passed = (max_rel_diff < self.rtol or max_abs_diff < self.atol)
        
        print(f"\n=== 数值等价性验证 ===")
        print(f"最大绝对误差: {max_abs_diff:.2e}")
        print(f"最大相对误差: {max_rel_diff:.2e}")
        print(f"平均相对误差: {mean_rel_diff:.2e}")
        print(f"容限: rtol={self.rtol}, atol={self.atol}")
        print(f"结果: {'✅ 通过' if passed else '❌ 失败'}")
        
        return passed, O_ref, O_fa
    
    def verify_softmax_properties(self, attn_weights):
        """
        验证softmax的数学性质
        
        性质1:每行和为1
        性质2:所有元素非负
        性质3:最大值位置概率最高
        """
        
        print(f"\n=== Softmax数学性质验证 ===")
        
        errors = []
        
        # 性质1:每行和为1
        row_sums = attn_weights.sum(dim=-1)
        sum_error = (row_sums - 1.0).abs().max().item()
        
        print(f"1. 行和=1: max误差={sum_error:.2e} {'✅' if sum_error < 1e-4 else '❌'}")
        if sum_error > 1e-4:
            errors.append(f"行和误差: {sum_error}")
        
        # 性质2:非负
        min_val = attn_weights.min().item()
        print(f"2. 非负: min={min_val:.2e} {'✅' if min_val >= -1e-6 else '❌'}")
        if min_val < -1e-6:
            errors.append(f"负值: {min_val}")
        
        # 性质3:最大值位置概率最高
        max_idx = attn_weights.argmax(dim=-1)
        max_val = attn_weights.amax(dim=-1)
        diag_check = torch.all(max_val > attn_weights.mean(dim=-1))
        print(f"3. 集中性: max>mean {'✅' if diag_check else '❌'}")
        
        return len(errors) == 0

边界条件测试

class BoundaryConditionTester:
    """
    边界条件测试
    
    覆盖场景:
      1. 空序列(长度为0)
      2. 单token序列
      3. 单head
      4. 极端值(无穷大、NaN、接近0)
      5. 非对齐长度(非block_size倍数)
      6. batch_size>1
    """
    
    def test_all_boundaries(self, flash_attention_fn):
        """运行所有边界测试"""
        
        results = []
        
        print("\n=== 边界条件测试 ===")
        
        # 测试1:空序列
        results.append(("空序列", self._test_empty_sequence(flash_attention_fn)))
        
        # 测试2:单token
        results.append(("单token", self._test_single_token(flash_attention_fn)))
        
        # 测试3:单head
        results.append(("单head", self._test_single_head(flash_attention_fn)))
        
        # 测试4:极端数值
        results.append(("极端数值", self._test_extreme_values(flash_attention_fn)))
        
        # 测试5:非对齐长度
        results.append(("非对齐长度", self._test_unaligned_length(flash_attention_fn)))
        
        # 测试6:大批量
        results.append(("大批量", self._test_large_batch(flash_attention_fn)))
        
        # 测试7:极长序列
        results.append(("极长序列", self._test_very_long_sequence(flash_attention_fn)))
        
        # 汇总
        print("\n=== 边界测试汇总 ===")
        all_passed = True
        for name, (passed, msg) in results:
            status = "✅" if passed else "❌"
            print(f"{status} {name}: {msg}")
            if not passed:
                all_passed = False
        
        return all_passed
    
    def _test_single_token(self, fn):
        """测试单token序列"""
        
        q = torch.randn(1, 1, 1, 128)
        k = torch.randn(1, 1, 1, 128)
        v = torch.randn(1, 1, 1, 128)
        
        try:
            O = fn(q, k, v)
            # 单token时,attention应该是全1(因为只有一个key)
            # O = V[0] * 1 = V[0]
            expected = v
            diff = (O - expected).abs().max().item()
            passed = diff < 1e-4
            return passed, f"max_diff={diff:.2e}"
        except Exception as e:
            return False, str(e)
    
    def _test_extreme_values(self, fn):
        """测试极端数值"""
        
        test_cases = [
            ("全0", torch.zeros(1, 1, 4, 128), torch.zeros(1, 1, 4, 128), torch.zeros(1, 1, 4, 128)),
            ("极大值", torch.full((1, 1, 4, 128), 1e3), torch.full((1, 1, 4, 128), 1e3), torch.ones(1, 1, 4, 128)),
            ("极小值", torch.full((1, 1, 4, 128), 1e-10), torch.full((1, 1, 4, 128), 1e-10), torch.ones(1, 1, 4, 128)),
            ("混合", torch.randn(1, 1, 4, 128) * 1e2, torch.randn(1, 1, 4, 128) * 1e2, torch.randn(1, 1, 4, 128)),
        ]
        
        for name, q, k, v in test_cases:
            try:
                O = fn(q, k, v)
                if torch.isnan(O).any() or torch.isinf(O).any():
                    return False, f"{name}产生NaN/Inf"
            except Exception as e:
                return False, f"{name}: {e}"
        
        return True, "所有极端值通过"
    
    def _test_unaligned_length(self, fn):
        """测试非block_size倍数的序列"""
        
        # block_size=128,测试129和255(都不是128的倍数)
        for seq_len in [129, 255, 300, 1000]:
            q = torch.randn(1, 1, seq_len, 128)
            k = torch.randn(1, 1, seq_len, 128)
            v = torch.randn(1, 1, seq_len, 128)
            
            try:
                O = fn(q, k, v)
                if O.shape != q.shape:
                    return False, f"seq_len={seq_len}输出shape不匹配"
            except Exception as e:
                return False, f"seq_len={seq_len}: {e}"
        
        return True, "所有非对齐长度通过"
    
    def _test_very_long_sequence(self, fn):
        """测试超长序列"""
        
        # 模拟16384长度的序列
        seq_len = 2048  # 缩小测试规模
        
        q = torch.randn(1, 1, seq_len, 128)
        k = torch.randn(1, 1, seq_len, 128)
        v = torch.randn(1, 1, seq_len, 128)
        
        try:
            O = fn(q, k, v)
            
            # 验证:最后一行的attention只看自己(因果mask)
            # 由于是自回归,实际上最后一个query只看最后一个key
            if O.shape == q.shape:
                return True, f"seq_len={seq_len}通过"
            else:
                return False, "输出shape不匹配"
        except Exception as e:
            return False, str(e)

回归测试框架

class RegressionTestSuite:
    """
    回归测试套件
    
    每次代码改动后运行,确保不引入新的bug
    """
    
    def __init__(self):
        self.test_cases = []
        self.results_history = []  # 保存历史结果用于对比
    
    def add_test_case(self, name, q, k, v, expected_O=None, rtol=1e-3, atol=1e-5):
        """添加测试用例"""
        
        self.test_cases.append({
            "name": name,
            "q": q,
            "k": k,
            "v": v,
            "expected_O": expected_O,
            "rtol": rtol,
            "atol": atol
        })
    
    def run_suite(self, flash_attention_fn, save_to_history=True):
        """运行完整测试套件"""
        
        print("\n" + "="*60)
        print("FlashAttention 回归测试套件")
        print("="*60)
        
        results = []
        failed_cases = []
        
        for i, tc in enumerate(self.test_cases):
            name = tc["name"]
            
            try:
                O_actual = flash_attention_fn(tc["q"], tc["k"], tc["v"])
                
                # 如果有expected_O,验证
                if tc["expected_O"] is not None:
                    abs_diff = (O_actual - tc["expected_O"]).abs()
                    rel_diff = abs_diff / (tc["expected_O"].abs() + 1e-8)
                    
                    max_rel = rel_diff.max().item()
                    max_abs = abs_diff.max().item()
                    
                    passed = (max_rel < tc["rtol"] or max_abs < tc["atol"])
                else:
                    # 无expected_O,只验证不报错
                    passed = not (torch.isnan(O_actual).any() or torch.isinf(O_actual).any())
                    max_rel = max_abs = 0.0
                
                status = "✅" if passed else "❌"
                print(f"{status} [{i+1}/{len(self.test_cases)}] {name}")
                if not passed:
                    print(f"   max_rel_diff={max_rel:.2e}, max_abs_diff={max_abs:.2e}")
                    failed_cases.append(name)
                
                results.append({
                    "name": name,
                    "passed": passed,
                    "max_rel": max_rel,
                    "max_abs": max_abs
                })
                
            except Exception as e:
                print(f"❌ [{i+1}/{len(self.test_cases)}] {name}: {e}")
                results.append({
                    "name": name,
                    "passed": False,
                    "error": str(e)
                })
                failed_cases.append(name)
        
        # 统计
        passed_count = sum(1 for r in results if r["passed"])
        total_count = len(results)
        
        print("\n" + "="*60)
        print(f"测试结果: {passed_count}/{total_count} 通过")
        
        if failed_cases:
            print(f"失败用例: {failed_cases}")
        else:
            print("✅ 全部通过!")
        
        if save_to_history:
            self.results_history.append({
                "timestamp": str(torch.cuda.Event(ElapsedTime=0)),  # 占位
                "results": results
            })
        
        return passed_count == total_count, results


def generate_regression_test_data():
    """生成回归测试数据"""
    
    print("\n=== 生成回归测试数据 ===")
    
    test_suite = RegressionTestSuite()
    
    # Case 1: 标准随机
    for seq_len in [4, 16, 64, 256, 1024]:
        for dtype in [torch.float16, torch.float32]:
            q = torch.randn(2, 8, seq_len, 64, dtype=dtype)
            k = torch.randn(2, 8, seq_len, 64, dtype=dtype)
            v = torch.randn(2, 8, seq_len, 64, dtype=dtype)
            
            test_suite.add_test_case(
                f"random_seq{seq_len}_{dtype}",
                q, k, v
            )
    
    # Case 2: 对称矩阵
    q = torch.randn(1, 4, 32, 64)
    k = q.clone()
    v = torch.randn(1, 4, 32, 64)
    test_suite.add_test_case("symmetric_qk", q, k, v)
    
    # Case 3: 稀疏注意力模式
    q = torch.randn(1, 4, 128, 64)
    k = torch.zeros(1, 4, 128, 64)
    k[:, :, ::4, :] = torch.randn(1, 4, 32, 64)  # 每4个token一个key
    v = torch.randn(1, 4, 128, 64)
    test_suite.add_test_case("sparse_attention", q, k, v)
    
    # Case 4: 重复token
    q = torch.randn(1, 4, 64, 64)
    k = torch.cat([q[:, :, :32, :]] * 2, dim=2)  # 前32个token重复
    v = torch.randn(1, 4, 64, 64)
    test_suite.add_test_case("repeated_keys", q, k, v)
    
    print(f"已生成 {len(test_suite.test_cases)} 个测试用例")
    
    return test_suite

形式化验证思路

def formal_verification_outline():
    """
    FlashAttention形式化验证大纲
    
    目标:数学上证明FlashAttention等价于标准Attention
    """
    
    print("\n=== FlashAttention形式化验证大纲 ===")
    
    verification_steps = [
        {
            "step": "1. 定义不变式 (Loop Invariant)",
            "content": """
在第i个block处理完成后,定义不变式I(i):
  
  I(i): 
    O[0:i] = Σ(j=0 to i-1) exp(S[0:i,j] - m_i) / L_i · V[j]
    m_i = max(S[0:i,*])
    L_i = Σ(j=0 to i-1) exp(S[0:i,j] - m_i)

其中:
  S[0:i,*] 是前i个query对所有keys的scores
  O[0:i] 是前i个query的输出
  m_i 和 L_i 是数值稳定的中间状态
            """,
            "proof_approach": "数学归纳法:证明I(0)成立,且I(i)→I(i+1)"
        },
        {
            "step": "2. 初始化 (Base Case)",
            "content": """
当i=0时,没有处理任何block:
  
  m_0 = -∞
  L_0 = 0
  O[0:0] = 空矩阵
  
此时I(0)平凡成立。
            """,
            "proof_approach": "直接代入验证"
        },
        {
            "step": "3. 归纳步骤 (Inductive Step)",
            "content": """
假设I(i)成立,处理第(i+1)个block:

  新的m_{i+1} = max(m_i, max(S[0:i+1, (i+1)*B:(i+2)*B]))
  
  关键恒等式:
    exp(S - m_{i+1}) = exp(S - m_i) * exp(m_i - m_{i+1})
    
  因此:
    L_{i+1} = L_i * exp(m_i - m_{i+1}) + l_{i+1}
    
    O_{i+1} = (O_i * exp(m_i - m_{i+1}) + O_{i+1,new}) / L_{i+1}
    
  这正是FlashAttention代码中的合并逻辑。
            """,
            "proof_approach": "代入关键恒等式,化简得到合并公式"
        },
        {
            "step": "4. 终止条件 (Termination)",
            "content": """
当所有T个blocks处理完成后:

  最终 m = m_T = max(S)
  最终 L = L_T = Σ exp(S - m)
  最终 O = Σ exp(S - m) / L · V = softmax(S) · V
  
  这正是标准Attention的输出。
  
  因此FlashAttention在数学上等价于标准Attention。
            """,
            "proof_approach": "取i=T,代入不变式I(T)"
        },
        {
            "step": "5. 数值稳定性证明",
            "content": """
关键观察:
  exp(x - m) ≤ 1,当 m = max(x)
  
因此:
  - 所有exp的输入都在(-∞, 0]范围内
  - exp不会溢出(但可能下溢为0,这是安全的)
  
  对于下溢:
    如果exp(x-m) = 0(完全下溢),
    相当于该位置的概率被忽略,
    这在数学上是正确的(概率确实接近0)。
            """,
            "proof_approach": "分析浮点数范围,证明无溢出"
        }
    ]
    
    for step in verification_steps:
        print(f"\n{step['step']}")
        print(f"  内容: {step['content']}")
        print(f"  证明方法: {step['proof_approach']}")

完整验证流程

def run_full_verification(flash_attention_fn):
    """
    运行完整的FlashAttention验证流程
    """
    
    print("\n" + "="*60)
    print("FlashAttention 完整验证流程")
    print("="*60)
    
    # Step 1: 生成测试数据
    test_suite = generate_regression_test_data()
    
    # Step 2: 运行回归测试
    print("\n[Step 1] 回归测试")
    all_passed, _ = test_suite.run_suite(flash_attention_fn)
    
    # Step 3: 边界条件测试
    print("\n[Step 2] 边界条件测试")
    boundary_tester = BoundaryConditionTester()
    boundary_passed = boundary_tester.test_all_boundaries(flash_attention_fn)
    
    # Step 4: 数值等价性测试
    print("\n[Step 3] 数值等价性测试")
    verifier = FlashAttentionVerifier(rtol=1e-3, atol=1e-5)
    
    test_configs = [
        (1, 8, 512, 64),   # 标准配置
        (1, 8, 2048, 64),  # 长序列
        (4, 32, 512, 128), # 大batch
        (1, 1, 128, 128),  # 单head
    ]
    
    all_eq_passed = True
    for B, H, S, D in test_configs:
        q = torch.randn(B, H, S, D)
        k = torch.randn(B, H, S, D)
        v = torch.randn(B, H, S, D)
        
        passed, _, _ = verifier.verify_numerical_equivalence(q, k, v)
        if not passed:
            all_eq_passed = False
    
    # 汇总
    print("\n" + "="*60)
    print("验证汇总")
    print("="*60)
    print(f"回归测试: {'✅ 通过' if all_passed else '❌ 失败'}")
    print(f"边界条件: {'✅ 通过' if boundary_passed else '❌ 失败'}")
    print(f"数值等价: {'✅ 通过' if all_eq_passed else '❌ 失败'}")
    
    final_result = all_passed and boundary_passed and all_eq_passed
    print(f"\n总体结论: {'✅ FlashAttention验证通过' if final_result else '❌ 存在问题,请检查上述失败项'}")
    
    return final_result

总结:FlashAttention正确性验证配置清单

验证类型 测试数量 关键检查点
回归测试 覆盖常见配置 随机、对称、稀疏、重复
边界条件 7类边界 空序列、单token、极端值、非对齐
数值等价 多配置对比 rtol=1e-3, atol=1e-5
Softmax性质 3项数学性质 行和=1、非负、集中性
形式化证明 5个步骤 不变式→归纳→终止→数值稳定性

判断标准

  • 所有测试通过 → 算法实现正确性有保证
  • 边界条件失败 → 可能有未处理的特殊情况
  • 数值等价失败 → 数值精度或算法逻辑有问题

代码和文档:

https://atomgit.com/cann/ops-transformer

Logo

作为“人工智能6S店”的官方数字引擎,为AI开发者与企业提供一个覆盖软硬件全栈、一站式门户。

更多推荐