CANN 算子拆解:FlashAttention 在 ops-transformer 里的实现逻辑
这篇文章深入解析了FlashAttention在昇腾NPU上的实现原理,指出常见的三大误区:1) 误将其视为单一算子而非融合策略;2) 忽视tiling参数对片上缓存利用的关键影响;3) 不了解在线Softmax和因果mask的优化机制。核心在于通过tiling分块、在线计算和跳过无效计算,确保中间结果始终驻留片上缓存,避免HBM访问瓶颈。文章特别强调默认配置已优化,但遇到性能问题需理解底层原理才
前言
上周有人在社区提了个 Issue:“为什么我在昇腾 NPU 上跑 FlashAttention,速度跟 PyTorch 原生 attention 差不多?”
我看了一眼他的代码,问题一目了然——他虽然 import 了 ATB 的
flash_attention,但传入的 tiling 参数是默认值,没按 NPU 的 Ub 缓存大小配置。融合算子确实在跑,但 tiling 不对的话,中间结果还是会溢出到 HBM,FlashAttention 等于白用了。这件事让我意识到:很多人把 FlashAttention 当黑盒用,跑通了就不管了,出了问题完全不知道从哪查。所以这篇文章不做教程,只做一件事——把 FlashAttention 在 ops-transformer 里的实现逻辑一层一层拆开,搞清楚每个部分在干什么,为什么这么干。
一、FlashAttention 不是"一个算子"
这是最常见的认知偏差。
FlashAttention 不是一个算子,是一个融合策略。
标准 Attention 是三个独立算子串行执行:
MatMul(Q×K)→ Softmax → MatMul(权重×V)
FlashAttention 的"融合"不是说把三个算子合并成一个大的算子,而是改变计算顺序和存储策略——让中间结果不再写回 HBM,直接留在 NPU 的片上缓存(Ub)里完成后续计算。
这个区别很重要:融合≠合并,融合=减少搬运。
所以当你看到 ops-transformer 源码里 FlashAttention 的实现,不要找"那个融合算子的代码"——它不是一个单独的 .cpp 文件,而是一套计算+存储的编排策略,横跨了 tiling、Softmax、因果 mask 三个子系统。
二、Tiling:不是切蛋糕,是拼拼图
Tiling 是 FlashAttention 里最容易被忽略、但最影响性能的部分。
认知纠偏: Tiling 不是"把大矩阵切成小块分别算"这么简单。如果只是切分再拼回去,那跟不 tiling 没区别——该搬 HBM 还是得搬。
Tiling 的真正目的是:让每一块 tile 的中间结果刚好能塞进 Ub,不溢出到 HBM。
昇腾达芬奇架构的 Ub 大小是固定的(每个 Core 约 64KB)。ops-transformer 里的 tiling 参数就是按这个容量倒推出来的:
Ub 容量 = 64KB(每个 Core)
FP16 精度下,一个 float16 元素占 2 字节
一个 128×128 的 tile = 128 × 128 × 2 = 32KB → 能塞进 Ub,还有余量放其他中间变量
所以 ops-transformer 默认 TILE_SIZE = 128
如果你在 ATB 里手动传了 tile_size=64,算子照样跑,但 Ub 空间利用率低——原来一块能装下的中间数据,现在要两块才能处理完,性能反而下降。
如果你的 FlashAttention 跑起来不快,第一个查的就是 tiling 参数。
三、在线 Softmax:一遍搞定,不回头
标准 Softmax 的计算步骤:
第一遍:扫描整个向量,找最大值 max_val
第二遍:用 max_val 做归一化,算 exp(x - max_val),求和
第三遍:除以 sum,得到最终概率
这在 GPU 上没问题,因为全局内存足够大,两遍扫描的中间结果可以随时回溯。
但 FlashAttention 要求所有计算都在 Ub 里完成——Ub 装不下整个 Softmax 向量,你没法"回头扫第二遍"。
ops-transformer 的解决方案:在线 Softmax(Online Softmax)。
核心思想是:一遍扫描,同时更新 max 和 sum,不需要回头。
// 简化版在线 Softmax 逻辑
float local_max = -INFINITY;
float local_sum = 0.0f;
for (int i = 0; i < tile_len; i++) {
float new_max = max(local_max, score[i]);
// 用新旧 max 的差值修正之前累加的 sum
local_sum *= exp(local_max - new_max);
local_sum += exp(score[i] - new_max);
local_max = new_max;
}
// 最终结果:exp(score - local_max) / local_sum
关键细节: local_sum *= exp(local_max - new_max) 这行——每次遇到更大的值,要把之前累加的 sum 做一次缩放修正。这个修正保证了最终结果的数值精度跟标准两遍 Softmax 等价。
这也是为什么 ops-transformer 的 FlashAttention 在精度上能跟 PyTorch 原生实现对齐——不是近似,是数学等价的另一种计算顺序。
四、因果 Mask:不算比算了再扔更快
大模型推理是自回归的——每个 token 只能看到之前的 token,不能偷看未来。
传统做法:先算完整的 attention 分数矩阵,再用一个下三角 mask 把"未来"位置置零。
ops-transformer 的做法更聪明:在 tiling 的时候直接跳过不需要算的 tile。
假设序列长度 2048,tile 大小 128
标准做法:16×16 = 256 个 tile 全算,再做 mask → 浪费了上三角的 120 个 tile
ops-transformer:只算下三角的 136 个 tile → 节省 47% 的计算量
这跟"算了再 mask"的区别:一个是做了无用功再扔掉,一个是压根不做。后者的计算量直接砍半,而且是跟序列长度成二次方关系——越长省得越多。
五、和 ascend-transformer-boost 的分工
搞清楚了 ops-transformer 的实现逻辑,再看 ATB 就清晰了:
ATB(调度层)
→ 决定用什么融合策略
→ 管理 tiling 参数配置
→ 多算子之间的协同调度
ops-transformer(实现层)
→ FlashAttention 的具体计算逻辑
→ Tiling / 在线 Softmax / 因果 mask 的硬件级实现
opbase(基础层)
→ 通用算子组件,所有 ops-* 仓库共享
ATB 的 flash_attention() 接口帮你配好了 tiling、开好了因果 mask、选好了融合策略。如果你只用默认配置,完全不用碰 ops-transformer。
但如果你需要:
- 自定义 tiling 大小(适配特殊序列长度)
- 修改在线 Softmax 的精度策略
- 调整因果 mask 的实现方式
那就得进 ops-transformer 的源码改。
总结:一句话说就是
FlashAttention 在 ops-transformer 里的实现拆开来看就三件事:Tiling 按 Ub 容量分块、在线 Softmax 一遍扫描、因果 mask 跳过不必要计算。三者配合,核心目标就一个——让中间结果不离开片上缓存。
ATB 是默认配置的一键开关,ops-transformer 是手动挡——自动挡够用就别换手动,但出了问题你得知道手动挡的原理才能查。
更多推荐




所有评论(0)