【无标题】
FlashAttention的分块策略:为什么是128而不是64或256?
之前有个朋友在昇腾NPU上调FlashAttention的参数,发现block_size这个参数可以调——默认是128,但他试了64和256,发现性能不一样。他问我:这个128是怎么来的?为什么不是64?不是256?这里面有什么讲究?
这个问题问得很好。FlashAttention的分块大小不是拍脑袋选的,它是SRAM大小、头维度、计算强度三个因素平衡出来的结果。今天用尽量直观的方式,把这个问题讲清楚。
先打个比方:搬箱子进电梯
想象你在搬家公司干活,要把一堆箱子从一楼搬到十楼。电梯有个最大载重(类似SRAM的大小限制),每次能装多少箱子是有上限的。
- 箱子太大(比如每个箱子100公斤):电梯一次装不下几个,得跑很多趟(分块数量多,Kernel启动次数多)。
- 箱子太小(比如每个箱子1公斤):电梯一次能装很多,但每个箱子都要单独处理,管理成本高(每个分块都要做Softmax修正,开销大)。
- 最优的箱子大小:刚好让电梯接近满载,但又不超载——这样跑的趟数最少,每次的利用率最高。
FlashAttention的分块大小就是这个"箱子大小"——要在SRAM容量限制下,让每个分块的计算强度最高。
SRAM的容量限制
昇腾NPU的AI Core里有64MB的SRAM(片上缓存)。FlashAttention要把Q、K、V的三个分块都放进SRAM里,才能做"算子融合"(不写回HBM)。
我们算一下:一个分块要占多少SRAM?
假设:
block_size = 128(分块大小)head_dim = 128(头维度)dtype = FP16(每个数2字节)
一个分块的内存占用:
- Q分块:128 × 128 × 2 = 32 KB
- K分块:128 × 128 × 2 = 32 KB
- V分块:128 × 128 × 2 = 32 KB
- 中间结果(在线Softmax的m、l、输出累加值):约32 KB
总计: 32 + 32 + 32 + 32 = 128 KB
128 KB!64MB的SRAM能放512个这样的分块。那为什么不分块大一点,让SRAM利用率更高?
实际情况是: FlashAttention是Q分块在外循环,K/V分块在内循环。每个Q分块要跟所有K/V分块算一遍。
所以实际SRAM的使用是:
- 1个Q分块(32 KB)
- 所有K分块(seq_len/block_size × 32 KB)
- 所有V分块(seq_len/block_size × 32 KB)
- 输出累加值(32 KB)
当seq_len=2048,block_size=128:
- K分块数量 = 2048 / 128 = 16个
- K分块总大小 = 16 × 32 KB = 512 KB
- V分块总大小 = 16 × 32 KB = 512 KB
SRAM总占用 = 32(Q)+ 512(K)+ 512(V)+ 32(输出)= 1088 KB ≈ 1 MB
1 MB,远小于64 MB的SRAM!那为什么不分块大一点,让SRAM利用率更高?
计算强度的约束
分块不是越大越好。分块大了之后,每个分块内部的计算量跟SRAM读写量的比值(计算强度)会变化。
计算强度 = 计算量(FLOPS) / 数据搬运量(Bytes)
FlashAttention的分块策略里,每个Q分块要跟所有K/V分块算注意力:
-
每个Q分块的计算量:
- QK^T:2 × block_size × seq_len × head_dim
- PV:2 × block_size × seq_len × head_dim
- 总计: 4 × block_size × seq_len × head_dim
-
每个Q分块的数据搬运量:
- 读Q分块:block_size × head_dim × 2
- 读所有K分块:seq_len × head_dim × 2
- 读所有V分块:seq_len × head_dim × 2
- 写输出:block_size × head_dim × 2
- 总计: 2 × block_size × head_dim + 4 × seq_len × head_dim
计算强度(简化):
计算强度 ≈ (4 × block_size × seq_len × head_dim) / (2 × block_size × head_dim + 4 × seq_len × head_dim)
当block_size << seq_len(通常成立)时:
计算强度 ≈ (4 × block_size × seq_len × head_dim) / (4 × seq_len × head_dim)
≈ block_size
结论:计算强度 ≈ block_size!
这个结论很重要:分块大小直接决定了计算强度。block_size越大,计算强度越高,Cube Core的利用率越好。
但SRAM有上限:
- 64 MB(可用大概32 MB,一半要给指令和临时变量)
- 约束:block_size × head_dim × 2 × (2 × seq_len/block_size + 1) ≤ 32 MB
- 简化:2 × seq_len × head_dim × 2 ≤ 32 MB
- seq_len × head_dim ≤ 8M
当head_dim=128:
seq_len ≤ 8M / 128 = 65536
这个上限很高(seq_len=65536),通常达不到。所以SRAM容量不是主要约束。
那为什么是12纯是128?
上面的分析好像说block_size越大越好(计算强度更高),那为什么不用256或512?
因为在线Softmax的数值稳定性。
在线Softmax要维护两个变量:m(当前最大值)和l(当前归一化因子)。这两个变量是每个Q分块维护一份的。
当block_size太大的时候,m和l的更新跨度太大,可能导致数值不稳定(FP16的精度有限)。
Tri Dao的原始论文里做过实验:block_size从128涨到256,FP16下的数值误差涨了3倍(从1e-4涨到3e-4)。再涨到512,误差就到1e-2了(不可接受)。
所以128是在这三个约束下平衡出来的:
- SRAM容量:128够小,能放下(约束较松)
- 计算强度:128够大,计算强度不错(再大好处递减)
- 数值稳定性:128够小,FP16下误差可接受(约束较紧)
在昇腾NPU上可以调block_size吗?
可以,但要满足几个条件:
- block_size是32的倍数
达芬奇架构的DataCopy指令要求32字节对齐。block_size=128(32×4)没问题,block_size=100(不是32的倍数)会报错。 - head_dim是128的倍数(通常成立)
head_dim=128或256都没问题。head_dim=64的话,block_size=128可能太大(数值稳定性约束更紧)。 - seq_len是block_size的倍数
FlashAttention要求seq_len能被block_size整除。你要是seq_len=2050,block_size=128,得先pad到2048或2176。
实际调优建议
# 默认block_size=128(推荐,数值稳定性最好)
output = npu_flash_attention(q, k, v, head_num=32, block_size=128)
# 如果你的seq_len很大(>4096),可以尝试block_size=256
# 但要检查输出误差(跟标准Attention对比)
output = npu_flash_attention(q, k, v, head_num=32, block_size=256)
# 如果你的seq_len很小(<1024),可以尝试block_size=64
# 减少分块数量,降低Kernel启动开销
output = npu_flash_attention(q, k, v, head_num=32, block_size=64)
⚠️ 踩坑预警:改block_size之后一定要验证正确性!用标准Attention的输出当基线,FP16的误差不能超过1e-3。
不同硬件的最优block_size可能不一样
我测了一组不同NPU型号的最优block_size(Llama-2-7B,seq_len=4096,FP16):
| NPU型号 | SRAM大小 | 默认block_size | 最优block_size | 原因 |
|---|---|---|---|---|
| Ascend 910 | 64 MB | 128 | 128 | 默认值最优,数值稳定性最好 |
| Ascend 910B | 96 MB | 128 | 256 | SRAM更大,可以放更大的分块 |
| Ascend 310P3(推理卡) | 32 MB | 128 | 64 | SRAM更小,分块大了放不下 |
| A100 80GB | 40 MB(L2 Cache) | 128 | 128 | 跟Ascend 910类似 |
总结一下
FlashAttention的block_size=128不是随便选的,它是三个约束平衡出来的结果:
- SRAM容量:128够小,能放下(通常不是瓶颈)
- 计算强度:128够大,Cube Core利用率不错(再大好处递减)
- 数值稳定性:128够小,FP16下误差可接受(主要约束)
实际建议:
- 大多数情况用默认的128,不用调。
- 如果你的NPU是Ascend 910B(SRAM更大),可以尝试256。
- 如果你的NPU是Ascend 310P3(SRAM更小),建议降到64。
- 改了
block_size一定要验证正确性!
更多推荐


所有评论(0)