FlashAttention的分块策略:为什么是128而不是64或256?

之前有个朋友在昇腾NPU上调FlashAttention的参数,发现block_size这个参数可以调——默认是128,但他试了64和256,发现性能不一样。他问我:这个128是怎么来的?为什么不是64?不是256?这里面有什么讲究?

这个问题问得很好。FlashAttention的分块大小不是拍脑袋选的,它是SRAM大小、头维度、计算强度三个因素平衡出来的结果。今天用尽量直观的方式,把这个问题讲清楚。

先打个比方:搬箱子进电梯

想象你在搬家公司干活,要把一堆箱子从一楼搬到十楼。电梯有个最大载重(类似SRAM的大小限制),每次能装多少箱子是有上限的。

  • 箱子太大(比如每个箱子100公斤):电梯一次装不下几个,得跑很多趟(分块数量多,Kernel启动次数多)。
  • 箱子太小(比如每个箱子1公斤):电梯一次能装很多,但每个箱子都要单独处理,管理成本高(每个分块都要做Softmax修正,开销大)。
  • 最优的箱子大小:刚好让电梯接近满载,但又不超载——这样跑的趟数最少,每次的利用率最高。

FlashAttention的分块大小就是这个"箱子大小"——要在SRAM容量限制下,让每个分块的计算强度最高。

SRAM的容量限制

昇腾NPU的AI Core里有64MB的SRAM(片上缓存)。FlashAttention要把Q、K、V的三个分块都放进SRAM里,才能做"算子融合"(不写回HBM)。

我们算一下:一个分块要占多少SRAM?

假设:

  • block_size = 128(分块大小)
  • head_dim = 128(头维度)
  • dtype = FP16(每个数2字节)

一个分块的内存占用:

  • Q分块:128 × 128 × 2 = 32 KB
  • K分块:128 × 128 × 2 = 32 KB
  • V分块:128 × 128 × 2 = 32 KB
  • 中间结果(在线Softmax的m、l、输出累加值):约32 KB

总计: 32 + 32 + 32 + 32 = 128 KB

128 KB!64MB的SRAM能放512个这样的分块。那为什么不分块大一点,让SRAM利用率更高?

实际情况是: FlashAttention是Q分块在外循环,K/V分块在内循环。每个Q分块要跟所有K/V分块算一遍。

所以实际SRAM的使用是:

  • 1个Q分块(32 KB)
  • 所有K分块(seq_len/block_size × 32 KB)
  • 所有V分块(seq_len/block_size × 32 KB)
  • 输出累加值(32 KB)

当seq_len=2048,block_size=128:

  • K分块数量 = 2048 / 128 = 16个
  • K分块总大小 = 16 × 32 KB = 512 KB
  • V分块总大小 = 16 × 32 KB = 512 KB

SRAM总占用 = 32(Q)+ 512(K)+ 512(V)+ 32(输出)= 1088 KB ≈ 1 MB

1 MB,远小于64 MB的SRAM!那为什么不分块大一点,让SRAM利用率更高?

计算强度的约束

分块不是越大越好。分块大了之后,每个分块内部的计算量跟SRAM读写量的比值(计算强度)会变化。

计算强度 = 计算量(FLOPS) / 数据搬运量(Bytes)

FlashAttention的分块策略里,每个Q分块要跟所有K/V分块算注意力:

  • 每个Q分块的计算量:

    • QK^T:2 × block_size × seq_len × head_dim
    • PV:2 × block_size × seq_len × head_dim
    • 总计: 4 × block_size × seq_len × head_dim
  • 每个Q分块的数据搬运量:

    • 读Q分块:block_size × head_dim × 2
    • 读所有K分块:seq_len × head_dim × 2
    • 读所有V分块:seq_len × head_dim × 2
    • 写输出:block_size × head_dim × 2
    • 总计: 2 × block_size × head_dim + 4 × seq_len × head_dim

计算强度(简化):
计算强度 ≈ (4 × block_size × seq_len × head_dim) / (2 × block_size × head_dim + 4 × seq_len × head_dim)

当block_size << seq_len(通常成立)时:
计算强度 ≈ (4 × block_size × seq_len × head_dim) / (4 × seq_len × head_dim)
≈ block_size

结论:计算强度 ≈ block_size!
这个结论很重要:分块大小直接决定了计算强度。block_size越大,计算强度越高,Cube Core的利用率越好。

但SRAM有上限:

  • 64 MB(可用大概32 MB,一半要给指令和临时变量)
  • 约束:block_size × head_dim × 2 × (2 × seq_len/block_size + 1) ≤ 32 MB
  • 简化:2 × seq_len × head_dim × 2 ≤ 32 MB
  • seq_len × head_dim ≤ 8M

当head_dim=128:
seq_len ≤ 8M / 128 = 65536
这个上限很高(seq_len=65536),通常达不到。所以SRAM容量不是主要约束。

那为什么是12纯是128?

上面的分析好像说block_size越大越好(计算强度更高),那为什么不用256或512?

因为在线Softmax的数值稳定性。

在线Softmax要维护两个变量:m(当前最大值)和l(当前归一化因子)。这两个变量是每个Q分块维护一份的。

当block_size太大的时候,m和l的更新跨度太大,可能导致数值不稳定(FP16的精度有限)。

Tri Dao的原始论文里做过实验:block_size从128涨到256,FP16下的数值误差涨了3倍(从1e-4涨到3e-4)。再涨到512,误差就到1e-2了(不可接受)。

所以128是在这三个约束下平衡出来的:

  • SRAM容量:128够小,能放下(约束较松)
  • 计算强度:128够大,计算强度不错(再大好处递减)
  • 数值稳定性:128够小,FP16下误差可接受(约束较紧)
在昇腾NPU上可以调block_size吗?

可以,但要满足几个条件:

  1. block_size是32的倍数
    达芬奇架构的DataCopy指令要求32字节对齐。block_size=128(32×4)没问题,block_size=100(不是32的倍数)会报错。
  2. head_dim是128的倍数(通常成立)
    head_dim=128或256都没问题。head_dim=64的话,block_size=128可能太大(数值稳定性约束更紧)。
  3. seq_len是block_size的倍数
    FlashAttention要求seq_len能被block_size整除。你要是seq_len=2050,block_size=128,得先pad到2048或2176。
实际调优建议
# 默认block_size=128(推荐,数值稳定性最好)
output = npu_flash_attention(q, k, v, head_num=32, block_size=128)

# 如果你的seq_len很大(>4096),可以尝试block_size=256
# 但要检查输出误差(跟标准Attention对比)
output = npu_flash_attention(q, k, v, head_num=32, block_size=256)

# 如果你的seq_len很小(<1024),可以尝试block_size=64
# 减少分块数量,降低Kernel启动开销
output = npu_flash_attention(q, k, v, head_num=32, block_size=64)

⚠️ 踩坑预警:改block_size之后一定要验证正确性!用标准Attention的输出当基线,FP16的误差不能超过1e-3。

不同硬件的最优block_size可能不一样

我测了一组不同NPU型号的最优block_size(Llama-2-7B,seq_len=4096,FP16):

NPU型号 SRAM大小 默认block_size 最优block_size 原因
Ascend 910 64 MB 128 128 默认值最优,数值稳定性最好
Ascend 910B 96 MB 128 256 SRAM更大,可以放更大的分块
Ascend 310P3(推理卡) 32 MB 128 64 SRAM更小,分块大了放不下
A100 80GB 40 MB(L2 Cache) 128 128 跟Ascend 910类似
总结一下

FlashAttention的block_size=128不是随便选的,它是三个约束平衡出来的结果:

  • SRAM容量:128够小,能放下(通常不是瓶颈)
  • 计算强度:128够大,Cube Core利用率不错(再大好处递减)
  • 数值稳定性:128够小,FP16下误差可接受(主要约束)

实际建议:

  • 大多数情况用默认的128,不用调。
  • 如果你的NPU是Ascend 910B(SRAM更大),可以尝试256。
  • 如果你的NPU是Ascend 310P3(SRAM更小),建议降到64。
  • 改了block_size一定要验证正确性!

代码和文档:
https://atomgit.com/cann/ops-transformer

Logo

作为“人工智能6S店”的官方数字引擎,为AI开发者与企业提供一个覆盖软硬件全栈、一站式门户。

更多推荐