FlashAttention为什么是对的:形式化验证与正确性证明
某团队在昇腾NPU上实现了FlashAttention,测试时发现大多数case都正确,但有几个边界case输出了错误结果——Loss只差一点点,但长序列场景下输出差异明显。他们怀疑是数值精度问题,也怀疑是算法实现有bug。问题是:没有系统性地验证FlashAttention的正确性。测试几个case不能证明算法是对的,需要形式化验证和边界条件覆盖。今天把FlashAttention的正确性验证方法讲清楚,给出完整的测试框架和形式化验证思路。
·
某团队在昇腾NPU上实现了FlashAttention,测试时发现大多数case都正确,但有几个边界case输出了错误结果——Loss只差一点点,但长序列场景下输出差异明显。他们怀疑是数值精度问题,也怀疑是算法实现有bug。
问题是:没有系统性地验证FlashAttention的正确性。测试几个case不能证明算法是对的,需要形式化验证和边界条件覆盖。
今天把FlashAttention的正确性验证方法讲清楚,给出完整的测试框架和形式化验证思路。
为什么FlashAttention的正确性需要专门验证
标准Attention vs FlashAttention的数学等价性
标准Attention:
$S_{ij} = \dfrac{Q_i \cdot K_j}{\sqrt{d}}$
$P_{ij} = \dfrac{\exp(S_{ij})}{\sum_{j'} \exp(S_{ij'})}$
$O_i = \sum_j P_{ij}\, V_j$
FlashAttention(分块计算):
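把K/V序列沿长度方向分成T个block,第t个block包含的key下标集合记为 $\mathcal{B}_t$
对第i行query,每个block t计算:
$m_i^{(t)} = \max_{j \in \mathcal{B}_t} S_{ij},\quad \ell_i^{(t)} = \sum_{j \in \mathcal{B}_t} e^{S_{ij} - m_i^{(t)}},\quad O_i^{(t)} = \frac{1}{\ell_i^{(t)}} \sum_{j \in \mathcal{B}_t} e^{S_{ij} - m_i^{(t)}}\, V_j$
最终合并(其中 $m_i = \max_t m_i^{(t)}$):
$O_i = \frac{1}{L_i} \sum_t e^{m_i^{(t)} - m_i}\, \ell_i^{(t)}\, O_i^{(t)}$
$L_i = \sum_t e^{m_i^{(t)} - m_i}\, \ell_i^{(t)}$
数学上要求:
$\frac{1}{L_i} \sum_t e^{m_i^{(t)} - m_i}\, \ell_i^{(t)}\, O_i^{(t)} = \sum_j P_{ij}\, V_j$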
这个等式的成立需要:
1. 每个block的m和l计算正确
2. 合并公式的数学推导正确
3. 数值稳定性处理(exp溢出)正确
4. 合并过程中每一步对m、l、O的重新缩放都正确(精确算术下合并结果与block顺序无关,但浮点舍入误差会随合并顺序变化)
数值等价性验证
分层测试框架
import torch
import numpy as np
from typing import List, Tuple
class FlashAttentionVerifier:
    """
    FlashAttention correctness verifier.

    Verification strategy:
    1. Numerical equivalence: compare the output against standard attention.
    2. Mathematical properties: check softmax properties (rows sum to 1,
       non-negative, max >= mean).
    3. Boundary conditions: empty sequence, single token, extreme values.
    4. Regression: make sure changes do not introduce new bugs.

    Tolerances follow ``torch.allclose`` semantics: an element passes when
    ``|actual - ref| <= atol + rtol * |ref|``.
    """

    def __init__(self, rtol=1e-3, atol=1e-5):
        self.rtol = rtol  # relative tolerance
        self.atol = atol  # absolute tolerance

    def standard_attention(self, q, k, v):
        """
        Reference (fully materialized) attention: softmax(Q K^T / sqrt(d)) V.

        Returns:
            (O, P): the attention output and the probability matrix P.
        """
        d = q.shape[-1]
        scale = 1.0 / (d ** 0.5)
        # S = Q @ K^T
        scores = torch.matmul(q, k.transpose(-2, -1)) * scale
        # P = softmax(S), numerically stable: subtract the row max before exp
        scores_max = scores.amax(dim=-1, keepdim=True)
        scores_exp = torch.exp(scores - scores_max)
        P = scores_exp / scores_exp.sum(dim=-1, keepdim=True)
        # O = P @ V
        O = torch.matmul(P, v)
        return O, P

    def flash_attention(self, q, k, v, block_size=128):
        """
        Tiled attention with online softmax (the implementation under test).

        K/V are processed in blocks of ``block_size`` keys along the sequence
        axis. For every query row a running max ``m``, a running normalizer
        ``l`` and an unnormalized output accumulator are kept; whenever a new
        block raises the running max, the previous state is rescaled by
        ``exp(m_old - m_new)``. The full score matrix is never materialized.

        Args:
            q: query tensor of shape (B, H, S_q, D).
            k, v: key/value tensors of shape (B, H, S_k, D) / (B, H, S_k, D_v);
                S_k may differ from S_q.
            block_size: number of keys per block (must be positive).

        Returns:
            (O, None): O has shape (B, H, S_q, D_v) and the dtype of ``q``;
            the second element is None because P is never materialized.

        Raises:
            ValueError: if ``block_size`` is not positive.
        """
        if block_size <= 0:
            raise ValueError(f"block_size must be positive, got {block_size}")

        out_dtype = q.dtype
        # Accumulate in float32 (as real FlashAttention kernels do) so that
        # fp16/bf16 inputs do not lose precision in the running sums.
        qf, kf, vf = q.float(), k.float(), v.float()
        B, H, S_q, D = qf.shape
        S_k = kf.shape[-2]
        scale = 1.0 / (D ** 0.5)

        acc = torch.zeros((B, H, S_q, vf.shape[-1]), dtype=torch.float32, device=q.device)
        m = torch.full((B, H, S_q, 1), -float("inf"), dtype=torch.float32, device=q.device)
        l = torch.zeros((B, H, S_q, 1), dtype=torch.float32, device=q.device)

        for start in range(0, S_k, block_size):
            end = min(start + block_size, S_k)
            k_blk = kf[:, :, start:end, :]
            v_blk = vf[:, :, start:end, :]

            # Scores of every query row against this key block: (B, H, S_q, blk)
            scores = torch.matmul(qf, k_blk.transpose(-2, -1)) * scale
            m_new = torch.maximum(m, scores.amax(dim=-1, keepdim=True))
            # Rows whose running max is still -inf (all scores -inf) would give
            # exp(-inf - (-inf)) = NaN; use 0 as the shift for those rows.
            m_safe = torch.where(torch.isneginf(m_new), torch.zeros_like(m_new), m_new)

            p = torch.exp(scores - m_safe)  # block weights relative to the new max
            alpha = torch.exp(m - m_safe)   # rescale factor for the old state, <= 1

            l = alpha * l + p.sum(dim=-1, keepdim=True)
            acc = alpha * acc + torch.matmul(p, v_blk)
            m = m_new

        # Normalize once at the end; rows that saw no keys (l == 0) stay 0
        # instead of becoming NaN.
        denom = torch.where(l > 0, l, torch.ones_like(l))
        O = (acc / denom).to(out_dtype)
        return O, None

    def verify_numerical_equivalence(self, q, k, v):
        """
        Check that ``flash_attention`` matches ``standard_attention``.

        An element passes when ``|O_fa - O_ref| <= atol + rtol * |O_ref|``;
        NaN/Inf in the output fail. Max/mean relative errors are printed for
        diagnosis only.

        Returns:
            (passed, O_ref, O_fa)
        """
        O_ref, _ = self.standard_attention(q, k, v)
        O_fa, _ = self.flash_attention(q, k, v)

        if O_fa.shape != O_ref.shape:
            print(f"\n=== 数值等价性验证 ===")
            print(f"输出shape不匹配: {tuple(O_fa.shape)} vs {tuple(O_ref.shape)}")
            print(f"结果: ❌ 失败")
            return False, O_ref, O_fa

        # Error statistics
        abs_diff = (O_ref - O_fa).abs()
        rel_diff = abs_diff / (O_ref.abs() + 1e-8)
        max_abs_diff = abs_diff.max().item()
        max_rel_diff = rel_diff.max().item()
        mean_rel_diff = rel_diff.mean().item()
        # Element-wise allclose criterion. A global "max_rel < rtol or
        # max_abs < atol" test is dominated by near-zero reference entries and
        # lets one tolerance mask failures judged by the other.
        passed = bool((abs_diff <= self.atol + self.rtol * O_ref.abs()).all())

        print(f"\n=== 数值等价性验证 ===")
        print(f"最大绝对误差: {max_abs_diff:.2e}")
        print(f"最大相对误差: {max_rel_diff:.2e}")
        print(f"平均相对误差: {mean_rel_diff:.2e}")
        print(f"容限: rtol={self.rtol}, atol={self.atol}")
        print(f"结果: {'✅ 通过' if passed else '❌ 失败'}")
        return passed, O_ref, O_fa

    def verify_softmax_properties(self, attn_weights):
        """
        Check mathematical properties of a softmax output (last dim = keys).

        Property 1: every row sums to 1 (tolerance 1e-4).
        Property 2: all elements are non-negative (tolerance 1e-6).
        Property 3 (concentration sanity check): each row's max is at least its
            mean (tolerance 1e-6). Uniform rows are valid and pass.

        NaN/Inf entries make every comparison fail, so they are reported.

        Returns:
            True iff all properties hold.
        """
        print(f"\n=== Softmax数学性质验证 ===")
        errors = []

        # Property 1: rows sum to 1. Written so that NaN counts as a failure.
        row_sums = attn_weights.sum(dim=-1)
        sum_error = (row_sums - 1.0).abs().max().item()
        sum_ok = sum_error < 1e-4
        print(f"1. 行和=1: max误差={sum_error:.2e} {'✅' if sum_ok else '❌'}")
        if not sum_ok:
            errors.append(f"行和误差: {sum_error}")

        # Property 2: non-negative. NaN min fails the >= comparison.
        min_val = attn_weights.min().item()
        nonneg_ok = min_val >= -1e-6
        print(f"2. 非负: min={min_val:.2e} {'✅' if nonneg_ok else '❌'}")
        if not nonneg_ok:
            errors.append(f"负值: {min_val}")

        # Property 3: max >= mean per row; the small slack absorbs rounding in
        # the mean of (near-)uniform rows.
        max_val = attn_weights.amax(dim=-1)
        mean_val = attn_weights.mean(dim=-1)
        conc_ok = bool((max_val >= mean_val - 1e-6).all())
        print(f"3. 集中性: max>=mean {'✅' if conc_ok else '❌'}")
        if not conc_ok:
            errors.append("集中性: 存在max<mean的行")

        return len(errors) == 0
边界条件测试
class BoundaryConditionTester:
    """
    Boundary-condition tests for a FlashAttention implementation.

    Covered scenarios:
    1. empty sequence (length 0)
    2. single-token sequence
    3. single head
    4. extreme values (all zeros, huge, tiny, large-magnitude random)
    5. unaligned lengths (not a multiple of block_size)
    6. batch_size > 1
    7. long sequence

    ``fn(q, k, v)`` may return the output tensor or a tuple whose first
    element is the output (e.g. ``FlashAttentionVerifier.flash_attention``).
    Cases with a well-defined answer are compared against a float64
    reference: an element passes when ``|out - ref| <= atol + rtol * |ref|``.
    """

    def __init__(self, rtol=1e-3, atol=1e-4, long_seq_len=2048):
        self.rtol = rtol  # relative tolerance for reference comparisons
        self.atol = atol  # absolute tolerance for reference comparisons
        self.long_seq_len = long_seq_len  # sequence length of the long-sequence test

    def test_all_boundaries(self, flash_attention_fn):
        """Run every boundary test, print a summary, and return True iff all passed."""
        print("\n=== 边界条件测试 ===")
        checks = [
            ("空序列", self._test_empty_sequence),
            ("单token", self._test_single_token),
            ("单head", self._test_single_head),
            ("极端数值", self._test_extreme_values),
            ("非对齐长度", self._test_unaligned_length),
            ("大批量", self._test_large_batch),
            ("极长序列", self._test_very_long_sequence),
        ]
        results = [(name, check(flash_attention_fn)) for name, check in checks]

        print("\n=== 边界测试汇总 ===")
        for name, (passed, msg) in results:
            status = "✅" if passed else "❌"
            print(f"{status} {name}: {msg}")
        return all(passed for _name, (passed, _msg) in results)

    @staticmethod
    def _output(result):
        """Return the output tensor from ``fn``'s result (tensor or tuple)."""
        return result[0] if isinstance(result, (tuple, list)) else result

    @staticmethod
    def _reference(q, k, v):
        """Float64 softmax(Q K^T / sqrt(d)) V used as ground truth."""
        qd, kd, vd = q.double(), k.double(), v.double()
        scores = torch.matmul(qd, kd.transpose(-2, -1)) / (qd.shape[-1] ** 0.5)
        return torch.matmul(torch.softmax(scores, dim=-1), vd)

    def _compare(self, out, q, k, v):
        """Compare ``out`` with the reference; return (passed, message). NaN/Inf fail."""
        expected_shape = (*q.shape[:-1], v.shape[-1])
        if tuple(out.shape) != expected_shape:
            return False, f"输出shape不匹配: {tuple(out.shape)} vs {expected_shape}"
        ref = self._reference(q, k, v)
        diff = (out.double() - ref).abs()
        passed = bool((diff <= self.atol + self.rtol * ref.abs()).all())
        return passed, f"max_diff={diff.max().item():.2e}"

    def _test_empty_sequence(self, fn):
        """Zero-length sequence: expect an empty output of q's shape and no error."""
        q = torch.randn(1, 1, 0, 128)
        k = torch.randn(1, 1, 0, 128)
        v = torch.randn(1, 1, 0, 128)
        try:
            O = self._output(fn(q, k, v))
        except Exception as e:
            return False, str(e)
        if O.shape != q.shape:
            return False, f"输出shape不匹配: {tuple(O.shape)}"
        return True, "输出为空且shape正确"

    def _test_single_token(self, fn):
        """Single key: its softmax weight is exactly 1, so the output must equal v."""
        q = torch.randn(1, 1, 1, 128)
        k = torch.randn(1, 1, 1, 128)
        v = torch.randn(1, 1, 1, 128)
        try:
            O = self._output(fn(q, k, v))
            if O.shape != v.shape:
                return False, f"输出shape不匹配: {tuple(O.shape)}"
            diff = (O - v).abs().max().item()
        except Exception as e:
            return False, str(e)
        return diff < 1e-4, f"max_diff={diff:.2e}"

    def _test_single_head(self, fn):
        """H == 1 (with batch > 1), compared against the reference."""
        q = torch.randn(2, 1, 64, 64)
        k = torch.randn(2, 1, 64, 64)
        v = torch.randn(2, 1, 64, 64)
        try:
            return self._compare(self._output(fn(q, k, v)), q, k, v)
        except Exception as e:
            return False, str(e)

    def _test_extreme_values(self, fn):
        """Extreme inputs must give finite outputs; deterministic cases must also be exact."""
        shape = (1, 1, 4, 128)
        test_cases = [
            # (name, q, k, v, compare_with_reference)
            ("全0", torch.zeros(shape), torch.zeros(shape), torch.zeros(shape), True),
            # Identical huge scores (~1.1e7): a softmax without max-subtraction overflows.
            ("极大值", torch.full(shape, 1e3), torch.full(shape, 1e3), torch.ones(shape), True),
            ("极小值", torch.full(shape, 1e-10), torch.full(shape, 1e-10), torch.ones(shape), True),
            # Scores of magnitude ~1e4 make softmax nearly one-hot; near-ties would
            # make a float32-vs-float64 value comparison flaky, so only finiteness is checked.
            ("混合", torch.randn(shape) * 1e2, torch.randn(shape) * 1e2, torch.randn(shape), False),
        ]
        for name, q, k, v, check_values in test_cases:
            try:
                O = self._output(fn(q, k, v))
                if torch.isnan(O).any() or torch.isinf(O).any():
                    return False, f"{name}产生NaN/Inf"
                if check_values:
                    passed, msg = self._compare(O, q, k, v)
                    if not passed:
                        return False, f"{name}: {msg}"
            except Exception as e:
                return False, f"{name}: {e}"
        return True, "所有极端值通过"

    def _test_unaligned_length(self, fn):
        """Lengths that are not a multiple of the default block_size (128), checked by value."""
        for seq_len in [129, 255, 300, 1000]:
            q = torch.randn(1, 1, seq_len, 128)
            k = torch.randn(1, 1, seq_len, 128)
            v = torch.randn(1, 1, seq_len, 128)
            try:
                passed, msg = self._compare(self._output(fn(q, k, v)), q, k, v)
            except Exception as e:
                return False, f"seq_len={seq_len}: {e}"
            if not passed:
                return False, f"seq_len={seq_len}: {msg}"
        return True, "所有非对齐长度通过"

    def _test_large_batch(self, fn):
        """batch_size > 1 with several heads; batches must not leak into each other."""
        q = torch.randn(8, 4, 128, 64)
        k = torch.randn(8, 4, 128, 64)
        v = torch.randn(8, 4, 128, 64)
        try:
            return self._compare(self._output(fn(q, k, v)), q, k, v)
        except Exception as e:
            return False, str(e)

    def _test_very_long_sequence(self, fn):
        """
        Long sequence (``self.long_seq_len``, default 2048; raise it to e.g.
        16384 for a full-scale run) checked against the reference.

        No causal mask is used anywhere here: every query attends to every key.
        NaN/Inf outputs fail the value comparison.
        """
        seq_len = self.long_seq_len
        q = torch.randn(1, 1, seq_len, 128)
        k = torch.randn(1, 1, seq_len, 128)
        v = torch.randn(1, 1, seq_len, 128)
        try:
            O = self._output(fn(q, k, v))
            if O.shape != q.shape:
                return False, "输出shape不匹配"
            passed, msg = self._compare(O, q, k, v)
        except Exception as e:
            return False, str(e)
        if passed:
            return True, f"seq_len={seq_len}通过"
        return False, f"seq_len={seq_len}: {msg}"
回归测试框架
class RegressionTestSuite:
    """
    Regression test suite, run after every code change to catch new bugs.

    ``flash_attention_fn(q, k, v)`` may return the output tensor or a tuple
    whose first element is the output.
    """

    def __init__(self):
        self.test_cases = []
        self.results_history = []  # results of past runs, kept for comparison

    def add_test_case(self, name, q, k, v, expected_O=None, rtol=1e-3, atol=1e-5):
        """Register a case. With ``expected_O=None`` only a NaN/Inf check is done."""
        self.test_cases.append({
            "name": name,
            "q": q,
            "k": k,
            "v": v,
            "expected_O": expected_O,
            "rtol": rtol,
            "atol": atol,
        })

    @staticmethod
    def _output(result):
        """Return the output tensor from the function's result (tensor or tuple)."""
        return result[0] if isinstance(result, (tuple, list)) else result

    def run_suite(self, flash_attention_fn, save_to_history=True):
        """
        Run every registered case and print a report.

        A case with ``expected_O`` passes when the output has the same shape and
        every element satisfies ``|actual - expected| <= atol + rtol * |expected|``
        (NaN/Inf fail). A case without ``expected_O`` only requires a finite output.
        Exceptions raised by a case are recorded as failures with an "error" entry.

        Returns:
            (all_passed, results) with one result dict per case.
        """
        from datetime import datetime, timezone

        print("\n" + "=" * 60)
        print("FlashAttention 回归测试套件")
        print("=" * 60)

        results = []
        failed_cases = []
        total = len(self.test_cases)

        for i, tc in enumerate(self.test_cases, start=1):
            name = tc["name"]
            try:
                O_actual = self._output(flash_attention_fn(tc["q"], tc["k"], tc["v"]))
                expected = tc["expected_O"]
                if expected is not None:
                    if O_actual.shape != expected.shape:
                        # Broadcasting would otherwise hide a wrong output shape.
                        raise ValueError(
                            f"输出shape {tuple(O_actual.shape)} 与期望 {tuple(expected.shape)} 不一致"
                        )
                    abs_diff = (O_actual - expected).abs()
                    rel_diff = abs_diff / (expected.abs() + 1e-8)
                    max_rel = rel_diff.max().item()
                    max_abs = abs_diff.max().item()
                    # Element-wise allclose criterion (NaN compares False -> fail).
                    passed = bool((abs_diff <= tc["atol"] + tc["rtol"] * expected.abs()).all())
                else:
                    # No expected output: only require a finite result.
                    passed = bool(torch.isfinite(O_actual).all())
                    max_rel = max_abs = 0.0
            except Exception as e:
                print(f"❌ [{i}/{total}] {name}: {e}")
                results.append({
                    "name": name,
                    "passed": False,
                    "error": str(e),
                })
                failed_cases.append(name)
                continue

            status = "✅" if passed else "❌"
            print(f"{status} [{i}/{total}] {name}")
            if not passed:
                print(f" max_rel_diff={max_rel:.2e}, max_abs_diff={max_abs:.2e}")
                failed_cases.append(name)
            results.append({
                "name": name,
                "passed": passed,
                "max_rel": max_rel,
                "max_abs": max_abs,
            })

        # Summary
        passed_count = sum(1 for r in results if r["passed"])
        total_count = len(results)
        print("\n" + "=" * 60)
        print(f"测试结果: {passed_count}/{total_count} 通过")
        if failed_cases:
            print(f"失败用例: {failed_cases}")
        else:
            print("✅ 全部通过!")

        if save_to_history:
            self.results_history.append({
                # torch.cuda.Event(ElapsedTime=0) is not a valid constructor call
                # (TypeError) and requires CUDA; record a UTC ISO-8601 time instead.
                "timestamp": datetime.now(timezone.utc).isoformat(),
                "results": results,
            })

        return passed_count == total_count, results
def generate_regression_test_data(seed=0, with_reference=True):
    """
    Build the standard FlashAttention regression suite (13 cases).

    Args:
        seed: seed for a private ``torch.Generator`` so the data is identical
            across runs without reseeding the global RNG; ``None`` draws a
            fresh non-deterministic seed.
        with_reference: attach a float64 reference output (stored as float32)
            to every case, so the suite checks values and not only NaN/Inf.
            fp16 cases get looser tolerances (rtol=1e-2, atol=2e-3).

    Returns:
        A populated ``RegressionTestSuite``.
    """
    print("\n=== 生成回归测试数据 ===")
    test_suite = RegressionTestSuite()

    # Private generator: reproducible regression data without side effects on
    # the global torch RNG used elsewhere in the program.
    gen = torch.Generator()
    if seed is None:
        gen.seed()
    else:
        gen.manual_seed(seed)

    def randn(*shape, dtype=torch.float32):
        return torch.randn(*shape, generator=gen, dtype=dtype)

    def add(name, q, k, v, **tolerances):
        # Ground truth computed in float64 from the exact input values.
        expected = None
        if with_reference:
            qd, kd, vd = q.double(), k.double(), v.double()
            scores = torch.matmul(qd, kd.transpose(-2, -1)) / (qd.shape[-1] ** 0.5)
            expected = torch.matmul(torch.softmax(scores, dim=-1), vd).float()
        test_suite.add_test_case(name, q, k, v, expected_O=expected, **tolerances)

    # Case 1: standard random inputs
    for seq_len in [4, 16, 64, 256, 1024]:
        for dtype in [torch.float16, torch.float32]:
            q = randn(2, 8, seq_len, 64, dtype=dtype)
            k = randn(2, 8, seq_len, 64, dtype=dtype)
            v = randn(2, 8, seq_len, 64, dtype=dtype)
            # fp16 outputs carry ~5e-4 relative rounding error; loosen tolerances.
            tolerances = {"rtol": 1e-2, "atol": 2e-3} if dtype == torch.float16 else {}
            add(f"random_seq{seq_len}_{dtype}", q, k, v, **tolerances)

    # Case 2: symmetric scores (K == Q)
    q = randn(1, 4, 32, 64)
    k = q.clone()
    v = randn(1, 4, 32, 64)
    add("symmetric_qk", q, k, v)

    # Case 3: sparse keys - only every 4th key is non-zero
    q = randn(1, 4, 128, 64)
    k = torch.zeros(1, 4, 128, 64)
    k[:, :, ::4, :] = randn(1, 4, 32, 64)
    v = randn(1, 4, 128, 64)
    add("sparse_attention", q, k, v)

    # Case 4: duplicated keys - keys 32..63 repeat keys 0..31, which are
    # themselves copies of the first 32 query vectors
    q = randn(1, 4, 64, 64)
    k = torch.cat([q[:, :, :32, :]] * 2, dim=2)
    v = randn(1, 4, 64, 64)
    add("repeated_keys", q, k, v)

    print(f"已生成 {len(test_suite.test_cases)} 个测试用例")
    return test_suite
形式化验证思路
def formal_verification_outline():
    """
    Print an outline of a correctness proof for FlashAttention.

    Goal: prove that tiled FlashAttention equals standard attention in exact
    arithmetic. The tiling runs over key/value blocks, so the proof is an
    induction over the number of processed key blocks, with a loop invariant
    that holds for every query row.

    Returns:
        The printed list of step dicts (keys "step", "content",
        "proof_approach"), so the outline can be reused programmatically.
    """
    import textwrap

    print("\n=== FlashAttention形式化验证大纲 ===")
    verification_steps = [
        {
            "step": "1. 定义不变式 (Loop Invariant)",
            "content": textwrap.dedent("""
                K/V沿序列维度按长度B分块,按顺序逐块处理。
                对每一行query r,在处理完前i个key块(key下标 j < i·B)之后,定义不变式I(i):
                I(i):
                  m_i[r] = max_{j<i·B} S[r,j]
                  L_i[r] = Σ_{j<i·B} exp(S[r,j] - m_i[r])
                  O_i[r] = Σ_{j<i·B} exp(S[r,j] - m_i[r]) / L_i[r] · V[j]
                其中:
                  S[r,j] 是第r个query与第j个key的缩放点积
                  O_i[r] 是第r行只对前i·B个key做softmax后的输出
                  m_i 和 L_i 是数值稳定的中间状态(逐行维护)
                注意:分块沿key维度进行,每处理一个块都会更新所有query行的状态。
                """),
            "proof_approach": "数学归纳法:证明I(0)成立,且I(i)→I(i+1)",
        },
        {
            "step": "2. 初始化 (Base Case)",
            "content": textwrap.dedent("""
                i=0时尚未处理任何key块:
                  m_0 = -∞
                  L_0 = 0
                  O_0 = 0(约定,对应空和)
                I(0)中的各个求和都是空和,不变式平凡成立。
                第一次合并时 exp(m_0 - m_1) = exp(-∞) = 0,旧状态被正确丢弃;
                但若m_1仍为-∞(该行分数全为-∞),exp(-∞ - (-∞)) = NaN,必须单独处理。
                """),
            "proof_approach": "直接代入验证",
        },
        {
            "step": "3. 归纳步骤 (Inductive Step)",
            "content": textwrap.dedent("""
                假设I(i)成立,处理第i个key块 J = [i·B, (i+1)·B)(下标从0开始,最后一块可能不足B个):
                  m_{i+1} = max(m_i, max_{j∈J} S[r,j])
                关键恒等式:
                  exp(S - m_{i+1}) = exp(S - m_i) · exp(m_i - m_{i+1})
                记新块相对新最大值的贡献:
                  l̃ = Σ_{j∈J} exp(S[r,j] - m_{i+1})
                  Õ = Σ_{j∈J} exp(S[r,j] - m_{i+1}) · V[j]
                因此:
                  L_{i+1} = L_i · exp(m_i - m_{i+1}) + l̃
                  O_{i+1} = (L_i · exp(m_i - m_{i+1}) · O_i + Õ) / L_{i+1}
                O_i是已归一化的输出,必须先乘回L_i再重新缩放。
                等价地,实现中常保存未归一化的累加器 A_i = L_i · O_i:
                  A_{i+1} = exp(m_i - m_{i+1}) · A_i + Õ,最后 O = A_T / L_T
                """),
            "proof_approach": "代入关键恒等式,化简得到合并公式",
        },
        {
            "step": "4. 终止条件 (Termination)",
            "content": textwrap.dedent("""
                当全部T = ⌈N/B⌉个key块处理完成后:
                  m_T = max_j S[r,j]
                  L_T = Σ_j exp(S[r,j] - m_T)
                  O_T[r] = Σ_j exp(S[r,j] - m_T) / L_T · V[j]
                softmax对整体平移不变:exp(S - m)/Σexp(S - m) = exp(S)/Σexp(S),因此
                  O_T[r] = softmax(S[r,:]) · V
                这正是标准Attention的输出。
                因此FlashAttention在精确算术下与标准Attention等价。
                """),
            "proof_approach": "取i=T,代入不变式I(T)",
        },
        {
            "step": "5. 数值稳定性证明",
            "content": textwrap.dedent("""
                关键观察:
                  m_{i+1} ≥ m_i,且 m_{i+1} ≥ S[r,j](j属于已处理的块)
                因此:
                  - exp(S[r,j] - m_{i+1}) ≤ 1,exp(m_i - m_{i+1}) ≤ 1
                  - 所有exp的输入都在(-∞, 0]范围内,exp不会上溢(但可能下溢为0)
                  - 取到最大值的那一项贡献exp(0) = 1,所以 L_{i+1} ≥ 1,归一化不会除以0
                对于下溢:
                  如果exp(x - m) = 0(完全下溢),该项相对最大项(贡献为1)
                  已小于浮点数可表示的最小正数,忽略它造成的相对误差可以忽略。
                  这是"误差可控"而非"结果精确":fp16/bf16累加仍有舍入误差,
                  因此实现通常用fp32累加m、l和输出。
                例外:
                  若某行分数全为-∞(例如整行被mask),m = -∞,
                  exp(-∞ - (-∞)) = NaN,实现中需单独处理(如把m替换为0)。
                """),
            "proof_approach": "分析浮点数范围,证明无上溢;单独处理全-∞行",
        },
    ]

    for step in verification_steps:
        print(f"\n{step['step']}")
        print(f" 内容: {step['content']}")
        print(f" 证明方法: {step['proof_approach']}")
    return verification_steps
完整验证流程
def run_full_verification(flash_attention_fn):
    """
    Run the complete FlashAttention verification pipeline on ``flash_attention_fn``.

    Stages: regression suite, boundary-condition tests, and numerical
    equivalence against ``FlashAttentionVerifier.standard_attention`` on
    several (B, H, S, D) configurations.

    Args:
        flash_attention_fn: callable ``(q, k, v)`` returning the attention
            output, or a tuple whose first element is the output.

    Returns:
        True iff every stage passed.
    """
    print("\n" + "=" * 60)
    print("FlashAttention 完整验证流程")
    print("=" * 60)

    # Build the regression data
    test_suite = generate_regression_test_data()

    # Stage 1: regression tests
    print("\n[Step 1] 回归测试")
    all_passed, _ = test_suite.run_suite(flash_attention_fn)

    # Stage 2: boundary conditions
    print("\n[Step 2] 边界条件测试")
    boundary_tester = BoundaryConditionTester()
    boundary_passed = boundary_tester.test_all_boundaries(flash_attention_fn)

    # Stage 3: numerical equivalence
    print("\n[Step 3] 数值等价性测试")
    verifier = FlashAttentionVerifier(rtol=1e-3, atol=1e-5)

    def fn_under_test(q, k, v):
        result = flash_attention_fn(q, k, v)
        out = result[0] if isinstance(result, (tuple, list)) else result
        return out, None

    # verify_numerical_equivalence compares standard_attention with
    # self.flash_attention. Shadow the verifier's built-in implementation on
    # this instance so that the function under test is what gets checked.
    verifier.flash_attention = fn_under_test

    test_configs = [
        (1, 8, 512, 64),     # standard configuration
        (1, 8, 2048, 64),    # long sequence
        (4, 32, 512, 128),   # large batch
        (1, 1, 128, 128),    # single head
    ]
    all_eq_passed = True
    for B, H, S, D in test_configs:
        q = torch.randn(B, H, S, D)
        k = torch.randn(B, H, S, D)
        v = torch.randn(B, H, S, D)
        passed, _, _ = verifier.verify_numerical_equivalence(q, k, v)
        if not passed:
            all_eq_passed = False

    # Summary
    print("\n" + "=" * 60)
    print("验证汇总")
    print("=" * 60)
    print(f"回归测试: {'✅ 通过' if all_passed else '❌ 失败'}")
    print(f"边界条件: {'✅ 通过' if boundary_passed else '❌ 失败'}")
    print(f"数值等价: {'✅ 通过' if all_eq_passed else '❌ 失败'}")
    final_result = all_passed and boundary_passed and all_eq_passed
    print(f"\n总体结论: {'✅ FlashAttention验证通过' if final_result else '❌ 存在问题,请检查上述失败项'}")
    return final_result
总结:FlashAttention正确性验证配置清单
| 验证类型 | 测试数量 | 关键检查点 |
|---|---|---|
| 回归测试 | 覆盖常见配置 | 随机、对称、稀疏、重复 |
| 边界条件 | 7类边界 | 空序列、单token、极端值、非对齐 |
| 数值等价 | 多配置对比 | rtol=1e-3, atol=1e-5 |
| Softmax性质 | 3项数学性质 | 行和=1、非负、集中性 |
| 形式化证明 | 5个步骤 | 不变式→归纳→终止→数值稳定性 |
判断标准:
- 所有测试通过 → 在已覆盖的配置和边界上与参考实现一致(测试只能发现错误、不能证明没有错误;算法层面的正确性由形式化证明保证)
- 边界条件失败 → 可能有未处理的特殊情况
- 数值等价失败 → 数值精度或算法逻辑有问题
代码和文档:
https://atomgit.com/cann/ops-transformer
更多推荐



所有评论(0)