CANN ops-transformer FlashAttention 里的因果掩码:分块计算时怎么防止“偷看未来“
有个问题困扰了我一阵子:昇腾CANN 的 ops-transformer 仓库里 FlashAttention 分块做在线 Softmax 的时候,因果掩码(causal mask)怎么处理?分块算 QK^T,你只有局部数据,怎么知道哪些位置应该被遮掉?翻了一遍代码,发现实现方式和我想的完全不一样。昇腾NPU 上 FlashAttention 的 causal mask 处理是直接在分块级别做遮挡
有个问题困扰了我一阵子:昇腾CANN 的 ops-transformer 仓库里 FlashAttention 分块做在线 Softmax 的时候,因果掩码(causal mask)怎么处理?分块算 QK^T,你只有局部数据,怎么知道哪些位置应该被遮掉?
翻了一遍代码,发现实现方式和我想的完全不一样。昇腾NPU 上 FlashAttention 的 causal mask 处理是直接在分块级别做遮挡,不是先算完再遮。ops-transformer 这个仓库是昇腾CANN 里 Transformer 类大模型进阶算子库,FlashAttention 只是其中最出名的那个。
❌ 先纠正一个常见误解
很多人以为 causal mask 就是"算完注意力矩阵,再把未来位置设成零"。
这个理解在标准注意力里没问题——你确实可以先算完整矩阵,再 mask。但 FlashAttention 不存完整矩阵。分块算 QK^T,每个分块算完就做在线 Softmax 修正,然后直接乘 V 输出。中间的注意力矩阵从不完整存在于显存里。
那你怎么 mask?你不能对一个不存在的东西做遮罩操作。
因果掩码必须在分块计算的时候就生效。不是后处理,是前置条件。
🔍 分块层面的 causal mask 怎么做
FlashAttention 的分块策略把 Q 按行分成 M 个块,K 按列分成 N 个块。形成 M×N 的分块网格。
因果掩码的规则很简单:token i 只能看到 token 0 到 i。在分块网格上,这意味着左下三角的分块完全被遮掉,右上三角的分块完全可见,对角线附近的分块部分可见。
具体处理方式:
完全遮挡的分块(Q块行号 > K块列号):直接跳过,不计算。这个分块的 QK^T 结果全是 -inf,对 Softmax 和最终输出没有任何贡献。跳过比算出来再遮要快——算一个全零分块是在浪费 Cube 单元时间。
完全可见的分块(Q块行号 ≤ K块列号且不跨越边界):正常计算,不做任何 mask 处理。
部分可见的分块(对角线附近):这是最麻烦的情况。一个分块内部,有些行能看到所有列,有些行只能看到部分列。需要在分块内部应用行级掩码。
c复制
// 分块网格遍历,causal模式只算有效分块
for (int br = 0; br < blocks_m; br++) {
float row_max = -INF;
float row_sum = 0.0;
// 为什么K块只到br?因为causal规则:Q行i只能看K列0~i
// br之后的K块整块跳过,省掉无意义计算
for (int bc = 0; bc <= br; bc++) {
auto s_block = cube_matmul(Q[br], K[bc]);
if (br == bc) {
// 对角线分块:内部做行级mask
// 第r行只能看到0~r列,超出的位置设-inf
apply_row_causal_mask(s_block, br * block_size);
}
// br < bc 时整块可见,不需要mask
// br > bc 的分块直接跳过了,循环条件已经排掉
row_max = update_max(row_max, s_block);
row_sum = update_sum(row_max, s_block);
}
write_final_output(br);
}
关键逻辑:bc <= br 这个循环条件就是 causal mask 在分块层面的体现。不是算完再遮,是压根就不算无效分块。
💡 对角线分块内部的行级掩码
最微妙的部分是对角线分块的处理。假设 block_size=64,分块 (br=2, bc=2) 对应 token 128~191 的 Q 和 token 128~191 的 K。
在这个分块内部:
- token 128(行0)只能看到 token 128(列0)
- token 129(行1)能看到 token 128~129(列0~1)
- token 191(行63)能看到 token 128~191(列0~63)
这是一个逐步展开的三角形遮罩,每行的遮罩边界不同。ops-transformer 用 Vector 单元在 Softmax 之前对每行做掩码——把超出边界的位置设为 -inf,Softmax 自然会把它们归零。
这个操作在片上缓存里完成,不需要额外显存。掩码数据本身很小(一个 block_size×block_size 的布尔矩阵),和 QK^T 分块一起留在 Cube/Vector 的片上缓存里。
📊 causal mask 对性能的影响
跳过无效分块直接省计算量。序列长度 4K,block_size=64 的时候:
| 模式 | 有效分块数 | 总分块数 | 节省比例 |
|---|---|---|---|
| 无 mask(双向注意力) | 64×64 = 4096 | 4096 | 0% |
| causal mask | 约 2080 | 4096 | 49% |
接近一半的分块直接跳过了。这不只是省显存——是省了 49% 的矩阵乘计算量。序列越长,节省比例越大。
128K 序列长度的时候,causal mask 节省约 98% 的无效计算。这就是为什么 FlashAttention + causal 在长序列场景上快得离谱——标准注意力算完了再遮,浪费的计算永远收不回来。
ops-transformer 的其他掩码支持
ops-transformer 的 FlashAttention 不只支持 causal:
- causal mask——自回归生成,不能看未来
- local mask——滑动窗口注意力,token i 只看前后 W 个 token,长上下文模型常用
- custom mask——用户自定义遮罩,支持任意 pattern
local mask 的分块处理和 causal 类似,但判断逻辑更复杂——每个分块需要看窗口范围是否覆盖。custom mask 则需要用户提供一个 mask tensor,在分块内做查找。
这些掩码都在分块层面前置处理,不是后处理。原则一样:不计算无效分块。
下一步
- 如果你做自回归模型推理,确认 FlashAttention 用的是 causal 模式——双向注意力在生成场景是错误的
- 去看 ops-transformer 仓库里 FlashAttention kernel 的 causal 分支实现,循环条件
bc <= br就是核心 - 长序列场景注意 local mask 的窗口设置,直接影响计算量和显存
- 推理部署走 ATB 接口,ATB 内部已经根据模型类型自动选择 mask 模式
更多推荐

所有评论(0)