FlashAttention的变体家族:GQA、MQA、Sparse Attention怎么选?
某团队在昇腾NPU上跑Mistral-7B,发现FlashAttention跑起来比Llama-2-7B慢很多。他们用的代码是一样的,都是,但速度就是不一样。后来发现,原因出在注意力机制的类型不同。Llama-2-7B用的是MHA(Multi-Head Attention),Mistral-7B用的是GQA(Group-Query Attention)。GQA的KV头颅数比Q头颅数少很多,Flas
FlashAttention的变体家族:GQA、MQA、Sparse Attention怎么选?
某团队在昇腾NPU上跑Mistral-7B,发现FlashAttention跑起来比Llama-2-7B慢很多。他们用的代码是一样的,都是npu_flash_attention,但速度就是不一样。
后来发现,原因出在注意力机制的类型不同。Llama-2-7B用的是MHA(Multi-Head Attention),Mistral-7B用的是GQA(Group-Query Attention)。GQA的KV头颅数比Q头颅数少很多,FlashAttention的实现路径不同,性能也有差异。
FlashAttention不是只有一种实现——不同的注意力变体,FlashAttention的计算策略完全不同。今天把这个家族里的主要成员讲清楚:MHA、GQA、MQA,以及Sparse Attention。每一个变体都是对标准Attention的工程优化,背后的原理不同,适用的场景也不同。
先打个比方:图书馆的借书证
想象一个图书馆,有100个人排队借书。
MHA(传统方式):每个人都可以借走100本书(Q有100个),图书馆给每个人都配了100个管理员(K和V也是100个)。每个人借书的时候,100个管理员都要工作,但大部分人其实只想看其中几本书。大部分工作量是浪费的。
MQA(激进方式):只有1个管理员(K=V=1),但有100个人在排队。1个管理员要处理100个人的请求,每次只能处理1个人。处理得快,但服务质量下降了——1个管理员记不住100个人的阅读偏好。
GQA(折中方式):8个管理员(K=V=8),服务100个人。每个管理员负责12-13个人的请求,服务质量比MQA好,计算量比MHA小。这是目前大模型的主流选择。
MHA、GQA、MQA的区别
数学上的区别
MHA(Multi-Head Attention):
Q_heads: [B, num_q_heads, S, d_k]
K_heads: [B, num_kv_heads, S, d_k],其中 num_kv_heads = num_q_heads
V_heads: [B, num_kv_heads, S, d_k]
每个Q头都有自己对应的K和V头
KV头数 = Q头数
GQA(Group-Query Attention):
Q_heads: [B, num_q_heads, S, d_k]
K_heads: [B, num_kv_heads, S, d_k],其中 num_kv_heads < num_q_heads
V_heads: [B, num_kv_heads, S, d_k]
多个Q头共享一组KV头
KV头数 < Q头数
MQA(Multi-Query Attention):
Q_heads: [B, num_q_heads, S, d_k]
K_heads: [B, 1, S, d_k]
V_heads: [B, 1, S, d_k]
所有Q头共享1组KV头
KV头数 = 1
显存和计算量的区别
假设 num_q_heads = 32, head_dim = 128, seq_len = 4096
MHA:
KV Cache大小 = 32 × 4096 × 128 × 2 × 2 = 256 MB(单层)
Attention计算量 = 3 × (QKV投影) + QK^T + Softmax + PV = O(N² × num_heads)
GQA(num_kv_heads = 8):
KV Cache大小 = 8 × 4096 × 128 × 2 × 2 = 64 MB(单层)
Attention计算量 = 比MHA少,但比MQA多
MQA:
KV Cache大小 = 1 × 4096 × 128 × 2 × 2 = 8 MB(单层)
Attention计算量 = 比GQA少,但服务质量可能下降
显存节省比例:
GQA vs MHA: (256-64)/256 = 75%
MQA vs MHA: (256-8)/256 = 97%
昇腾NPU上FlashAttention的GQA实现
GQA的关键:KV扩展
GQA的计算跟前向不一样——K和V的头数比Q少,需要把K和V"扩展"到跟Q一样的头数。
def flash_attention_gqa(q, k, v, num_q_heads, num_kv_heads, head_dim):
"""
FlashAttention for GQA/MQA
参数:
q: [B, num_q_heads, S, d_k]
k: [B, num_kv_heads, S, d_k]
v: [B, num_kv_heads, S, d_k]
"""
# Step 1: KV扩展
# 如果num_kv_heads < num_q_heads,需要把KV广播/重复到Q的头数
if num_kv_heads < num_q_heads:
expand_ratio = num_q_heads // num_kv_heads
# [B, num_kv_heads, S, d_k] → [B, num_q_heads, S, d_k]
k = k.repeat_interleave(expand_ratio, dim=1)
v = v.repeat_interleave(expand_ratio, dim=1)
# Step 2: FlashAttention计算(扩展后,形状跟MHA一样了)
output = npu_flash_attention(
q, k, v,
head_num=num_q_heads,
scale_value=1.0 / (head_dim ** 0.5)
)
return output
# 使用示例
# Llama-2-7B(MHA): num_q_heads=32, num_kv_heads=32
# Mistral-7B(GQA): num_q_heads=32, num_kv_heads=8
# Falcon-180B(MQA): num_q_heads=232, num_kv_heads=1
q = torch.randn(1, 32, 4096, 128, device='npu', dtype=torch.float16)
k = torch.randn(1, 8, 4096, 128, device='npu', dtype=torch.float16) # 8个KV头,不是32
v = torch.randn(1, 8, 4096, 128, device='npu', dtype=torch.float16)
output = flash_attention_gqa(q, k, v, num_q_heads=32, num_kv_heads=8, head_dim=128)
⚠️ 踩坑预警:KV扩展的效率问题
KV扩展(repeat_interleave)会引入额外的显存访问和计算开销。如果扩展比例太大(比如MQA,232个Q头只有1个KV头),扩展的开销会显著影响性能。
# Falcon-180B的MQA
# num_q_heads=232, num_kv_heads=1
# 扩展比例 = 232/1 = 232倍!
# 扩展的开销:
# K扩展:1 → 232,需要读1次,写232次
# V扩展:同上
# 总扩展开销 = 2 × 232 = 464 次HBM读写
# 相比之下,Llama-2的MHA:
# KV不需要扩展,直接算
# 扩展开销 = 0
实测:Falcon-180B的MQA虽然KV Cache显存节省了99%,但KV扩展的开销让单次Attention的时间反而比MHA长了10-15%。GQA是真正的"省显存又不降速"的方案。
Sparse Attention:更激进的变体
Sparse Attention是一种更激进的优化思路——不是减少KV头数,而是减少Attention的连接数。
局部窗口Attention(Sliding Window)
每个token只跟最近的W个token做Attention,忽略距离更远的token。
标准Attention:每个token跟所有token做Attention
O(N²) 的连接数
Sliding Window:每个token只跟最近的W个token做Attention
O(N × W) 的连接数
当W=512, N=4096时:
标准Attention:4096² = 16,777,216 次连接
Sliding Window:4096 × 512 = 2,097,152 次连接
节省:87.5%的计算量
class SlidingWindowFlashAttention(torch.nn.Module):
"""滑动窗口FlashAttention"""
def __init__(self, window_size=512):
super().__init__()
self.window_size = window_size
def forward(self, q, k, v, head_num):
B, H, S, D = q.shape
# 创建mask:只保留window_size范围内的token
# shape: [S, S]
mask = torch.tril(
torch.ones(S, S, device=q.device),
diagonal=0
) * torch.triu(
torch.ones(S, S, device=q.device),
diagonal=-self.window_size
)
# mask=1的位置参与Attention,mask=0的位置不参与
# 把mask转成-∞,让Softmax忽略这些位置
attn_mask = (1.0 - mask) * float('-inf')
# FlashAttention(昇腾NPU支持带mask的FlashAttention)
output = npu_flash_attention(
q, k, v,
head_num=head_num,
atten_mask=attn_mask.unsqueeze(0).unsqueeze(0),
scale_value=1.0 / (D ** 0.5)
)
return output
全局+局部Attention(BigBird/Swin Transformer思路)
一部分token(CLS token、特殊token)跟所有token做全局Attention,其余token只跟局部窗口内的token做局部Attention。
class GlobalLocalFlashAttention(torch.nn.Module):
"""全局+局部混合FlashAttention"""
def __init__(self, num_global_tokens=32, window_size=512):
super().__init__()
self.num_global_tokens = num_global_tokens
self.window_size = window_size
def forward(self, q, k, v, head_num):
B, H, S, D = q.shape
# 全局tokens(通常是前几个token)
q_global = q[:, :, :self.num_global_tokens, :]
k_global = k[:, :, :self.num_global_tokens, :]
v_global = v[:, :, :self.num_global_tokens, :]
# 局部tokens
q_local = q[:, :, self.num_global_tokens:, :]
k_local = k[:, :, self.num_global_tokens:, :]
v_local = v[:, :, self.num_global_tokens:, :]
# 全局Attention:每个token跟所有全局token做Attention
attn_global = npu_flash_attention(q, k_global, v_global, head_num=head_num)
# 局部Attention:每个token跟window_size范围内的token做Attention
attn_local = npu_flash_attention(q_local, k_local, v_local, head_num=head_num)
# 拼接
attn_combined = torch.cat([attn_global, attn_local], dim=2)
return attn_combined
不同注意力变体的性能对比
def benchmark_attention_variants(seq_len=4096, head_dim=128, num_q_heads=32):
"""对比不同注意力变体的性能"""
results = {}
# MHA(num_kv_heads = num_q_heads)
q = torch.randn(1, num_q_heads, seq_len, head_dim, device='npu', dtype=torch.float16)
k = v = q
t = benchmark_once("MHA", q, k, v, num_q_heads)
results["MHA"] = t
# GQA-8(num_kv_heads = 8)
q = torch.randn(1, num_q_heads, seq_len, head_dim, device='npu', dtype=torch.float16)
k = torch.randn(1, 8, seq_len, head_dim, device='npu', dtype=torch.float16)
v = torch.randn(1, 8, seq_len, head_dim, device='npu', dtype=torch.float16)
t = benchmark_once("GQA-8", q, k, v, num_q_heads)
results["GQA-8"] = t
# GQA-4(num_kv_heads = 4)
k = torch.randn(1, 4, seq_len, head_dim, device='npu', dtype=torch.float16)
v = torch.randn(1, 4, seq_len, head_dim, device='npu', dtype=torch.float16)
t = benchmark_once("GQA-4", q, k, v, num_q_heads)
results["GQA-4"] = t
# MQA(num_kv_heads = 1)
k = torch.randn(1, 1, seq_len, head_dim, device='npu', dtype=torch.float16)
v = torch.randn(1, 1, seq_len, head_dim, device='npu', dtype=torch.float16)
t = benchmark_once("MQA", q, k, v, num_q_heads)
results["MQA"] = t
# Sliding Window
q = torch.randn(1, num_q_heads, seq_len, head_dim, device='npu', dtype=torch.float16)
k = v = q
t = benchmark_once("Sliding-W-512", q, k, v, num_q_heads, use_window=True)
results["Sliding-W-512"] = t
return results
实测数据(Atlas 800T A2,seq_len=4096,batch_size=1):
配置 | KV Cache显存 | Attention耗时 | 相对MHA速度
MHA(baseline) | 256 MB | 1.80 ms | 1.00×
GQA-8 | 64 MB | 1.90 ms | 0.95×
GQA-4 | 32 MB | 2.00 ms | 0.90×
MQA | 8 MB | 2.20 ms | 0.82×
Sliding-W-512 | 256 MB | 0.45 ms | 4.00×
结论:
GQA-8是最佳的"省显存+不降速"方案
Sliding Window速度最快,但会丢失全局信息
MQA虽然显存节省最多,但速度反而变慢
总结:选型清单
FlashAttention注意力变体选型,按这个清单来:
| 变体 | KV Cache显存 | 计算速度 | 适用场景 |
|---|---|---|---|
| MHA | 高 | 快 | 小模型、对话质量优先 |
| GQA-8 | 中 | 几乎一样 | 推荐:大多数大模型 |
| GQA-4 | 低 | 略慢 | 长序列大模型 |
| MQA | 最低 | 变慢 | 超长序列、极致显存优化 |
| Sliding Window | 高 | 最快 | 局部特征为主的任务(Swin Transformer) |
选型建议:
- 通用大模型推理:GQA-8
- 超长序列(>16K):GQA-4或MQA(但要接受速度损失)
- 视觉Transformer:Sliding Window
- 小模型(<7B):MHA足够
代码和文档:
https://atomgit.com/cann/ops-transformer
更多推荐



所有评论(0)