有个问题困扰了我一阵子:昇腾CANN 的 ops-transformer 仓库里 FlashAttention 分块做在线 Softmax 的时候,因果掩码(causal mask)怎么处理?分块算 QK^T,你只有局部数据,怎么知道哪些位置应该被遮掉?

翻了一遍代码,发现实现方式和我想的完全不一样。昇腾NPU 上 FlashAttention 的 causal mask 处理是直接在分块级别做遮挡,不是先算完再遮。ops-transformer 这个仓库是昇腾CANN 里 Transformer 类大模型进阶算子库,FlashAttention 只是其中最出名的那个。

❌ 先纠正一个常见误解

很多人以为 causal mask 就是"算完注意力矩阵,再把未来位置设成零"。

这个理解在标准注意力里没问题——你确实可以先算完整矩阵,再 mask。但 FlashAttention 不存完整矩阵。分块算 QK^T,每个分块算完就做在线 Softmax 修正,然后直接乘 V 输出。中间的注意力矩阵从不完整存在于显存里

那你怎么 mask?你不能对一个不存在的东西做遮罩操作。

因果掩码必须在分块计算的时候就生效。不是后处理,是前置条件。

🔍 分块层面的 causal mask 怎么做

FlashAttention 的分块策略把 Q 按行分成 M 个块,K 按列分成 N 个块。形成 M×N 的分块网格。

因果掩码的规则很简单:token i 只能看到 token 0 到 i。在分块网格上,这意味着左下三角的分块完全被遮掉,右上三角的分块完全可见,对角线附近的分块部分可见。

具体处理方式:

完全遮挡的分块(Q块行号 > K块列号):直接跳过,不计算。这个分块的 QK^T 结果全是 -inf,对 Softmax 和最终输出没有任何贡献。跳过比算出来再遮要快——算一个全零分块是在浪费 Cube 单元时间。

完全可见的分块(Q块行号 ≤ K块列号且不跨越边界):正常计算,不做任何 mask 处理。

部分可见的分块(对角线附近):这是最麻烦的情况。一个分块内部,有些行能看到所有列,有些行只能看到部分列。需要在分块内部应用行级掩码。

c复制

// 分块网格遍历,causal模式只算有效分块
for (int br = 0; br < blocks_m; br++) {
 float row_max = -INF;
 float row_sum = 0.0;
 // 为什么K块只到br?因为causal规则:Q行i只能看K列0~i
 // br之后的K块整块跳过,省掉无意义计算
 for (int bc = 0; bc <= br; bc++) {
 auto s_block = cube_matmul(Q[br], K[bc]);
 if (br == bc) {
 // 对角线分块:内部做行级mask
 // 第r行只能看到0~r列,超出的位置设-inf
 apply_row_causal_mask(s_block, br * block_size);
 }
 // br < bc 时整块可见,不需要mask
 // br > bc 的分块直接跳过了,循环条件已经排掉
 
 row_max = update_max(row_max, s_block);
 row_sum = update_sum(row_max, s_block);
 }
 write_final_output(br);
}

关键逻辑:bc <= br 这个循环条件就是 causal mask 在分块层面的体现。不是算完再遮,是压根就不算无效分块

💡 对角线分块内部的行级掩码

最微妙的部分是对角线分块的处理。假设 block_size=64,分块 (br=2, bc=2) 对应 token 128~191 的 Q 和 token 128~191 的 K。

在这个分块内部:

  • token 128(行0)只能看到 token 128(列0)
  • token 129(行1)能看到 token 128~129(列0~1)
  • token 191(行63)能看到 token 128~191(列0~63)

这是一个逐步展开的三角形遮罩,每行的遮罩边界不同。ops-transformer 用 Vector 单元在 Softmax 之前对每行做掩码——把超出边界的位置设为 -inf,Softmax 自然会把它们归零。

这个操作在片上缓存里完成,不需要额外显存。掩码数据本身很小(一个 block_size×block_size 的布尔矩阵),和 QK^T 分块一起留在 Cube/Vector 的片上缓存里。

📊 causal mask 对性能的影响

跳过无效分块直接省计算量。序列长度 4K,block_size=64 的时候:

模式 有效分块数 总分块数 节省比例
无 mask(双向注意力) 64×64 = 4096 4096 0%
causal mask 约 2080 4096 49%

接近一半的分块直接跳过了。这不只是省显存——是省了 49% 的矩阵乘计算量。序列越长,节省比例越大。

128K 序列长度的时候,causal mask 节省约 98% 的无效计算。这就是为什么 FlashAttention + causal 在长序列场景上快得离谱——标准注意力算完了再遮,浪费的计算永远收不回来。

ops-transformer 的其他掩码支持

ops-transformer 的 FlashAttention 不只支持 causal:

  • causal mask——自回归生成,不能看未来
  • local mask——滑动窗口注意力,token i 只看前后 W 个 token,长上下文模型常用
  • custom mask——用户自定义遮罩,支持任意 pattern

local mask 的分块处理和 causal 类似,但判断逻辑更复杂——每个分块需要看窗口范围是否覆盖。custom mask 则需要用户提供一个 mask tensor,在分块内做查找。

这些掩码都在分块层面前置处理,不是后处理。原则一样:不计算无效分块

下一步

  1. 如果你做自回归模型推理,确认 FlashAttention 用的是 causal 模式——双向注意力在生成场景是错误的
  2. 去看 ops-transformer 仓库里 FlashAttention kernel 的 causal 分支实现,循环条件 bc <= br 就是核心
  3. 长序列场景注意 local mask 的窗口设置,直接影响计算量和显存
  4. 推理部署走 ATB 接口,ATB 内部已经根据模型类型自动选择 mask 模式

仓库地址:https://atomgit.com/cann/ops-transformer

Logo

作为“人工智能6S店”的官方数字引擎,为AI开发者与企业提供一个覆盖软硬件全栈、一站式门户。

更多推荐