调用 flash_attention() 时,CANN runtime 在干什么?
FlashAttention 的三个子算子之间有依赖(softmax 依赖 MatMul 的输出,第二个 MatMul 依赖 softmax 的输出),所以这条路径上 fusion 空间有限——但如果 attn_mask 为 None,编译器会在这个边界上尝试把 mask 补零操作吞进融合算子里,省掉一次独立的 kernel 启动。昇腾 NPU 上的 FlashAttention 之所以能跑出那样
你以为是一次调用,runtime 拆成了三次
大多数时候,大家用 PyTorch 写一行代码:
O = flash_attention(Q, K, V, attn_mask=None)
脑子里想象的画面是:进去了,出来了,就结束了。
但昇腾 NPU 上的 CANN runtime 收到这个调用时,它看到的不是"一次注意力计算",而是一张被拆解成多阶段执行流水线的算子图。FlashAttention 在这个图里是若干个子算子的组合,每一步的计算和数据搬运都由 runtime 的调度器统一管理。
这篇文章想拆掉这个黑箱。不是讲 FlashAttention 算法(那个已经很多人写过了),而是讲当算法落到昇腾 NPU 的 CANN runtime 上时,哪些地方被藏着,哪些地方其实你不得不理解。
第一层:runtime 怎么"看懂" FlashAttention
当模型用 PyTorch 写 attention 时,计算图里的节点是 torch.matmul / torch.softmax 这些高层算子,跟硬件没有直接对应关系。CANN runtime 的第一件事是把高层算子图翻译成昇腾 NPU 能执行的低层算子序列。
这个翻译层叫 Graph Compiler(图编译器),是 CANN 五层架构里"计算编译层"的核心组件。它干的事情有两部分:
算子拆解:FlashAttention 在 PyTorch 层是一个融合算子,Graph Compiler 把它拆成 MatMul(Q, K) → Softmax → MatMul(P, V) 三个子算子。这不是退步——拆解是为了让每个子算子都能独立做 tile 划分和内存调度。融合大 kernel 在 CPU/GPU 上有效,但在 NPU 上,如果 tile 策略不匹配达芬奇架构的 L1 大小,融合反而会让缓存局部性变差。
图优化:拆解之后,Graph Compiler 扫描整张图,看有没有可以合并的相邻节点。如果前后两个子算子之间没有依赖,可以把它们的执行交给同一个 task block,减少 kernel 启动开销。FlashAttention 的三个子算子之间有依赖(softmax 依赖 MatMul 的输出,第二个 MatMul 依赖 softmax 的输出),所以这条路径上 fusion 空间有限——但如果 attn_mask 为 None,编译器会在这个边界上尝试把 mask 补零操作吞进融合算子里,省掉一次独立的 kernel 启动。
关键认知:runtime 并不直接执行"FlashAttention",它执行的是一张拆解后经过优化的算子图。图的拓扑结构决定了数据在 NPU 各计算单元之间的流转路径。
第二层:内存布局,runtime 做的第一个关键决策
FlashAttention 的核心收益来自减少 HBM 访问,但"减少"不等于"不需要访问"——数据还是要从 HBM 加载到片上,在 Tensor Core 里算完,再写回去。
问题在于:数据以什么为单位、什么时机、以什么路径在 HBM 和片上之间搬运,这是 runtime 内存管理器的职责。
昇腾达芬奇架构的内存空间分三层:
- HBM(Global Memory):容量大,延迟高,带宽是瓶颈
- L1 Buffer:每个计算核独享,容量小(Ascend 910 上通常是 64KB),延迟低
- L0 Buffer:Tensor Core 直接访问的矩阵存储,容量比 L1 更小,但带宽最高
FlashAttention 的 tile 策略决定了 working set 的大小。tile 太小,working set 能全部放在 L1/L0,HBM 访问被压缩到最小;但 tile 太小意味着循环次数增加,loop overhead 变大。tile 太大,working set 溢出到 HBM,L1/L0 的命中率下降,每次迭代都要做一次完整的 DMA 搬运,效率反而退化。
CANN runtime 的内存管理器会基于当前芯片的 L1/L0 大小、tile 参数和输入 shape 自动计算最优 tile 大小。这个过程叫内存布局编排(Memory Layout Planning)。高级用户可以通过 asc-tools 里的配置参数手动覆盖自动决策:
# 手动指定 tile 策略,覆盖 runtime 的自动推断
import torch
from ascend_cann_ops import flash_attention
Q = torch.randn(4, 32, 4096, 128, dtype=torch.float16, device="npu")
K = torch.randn(4, 32, 4096, 128, dtype=torch.float16, device="npu")
V = torch.randn(4, 32, 4096, 128, dtype=torch.float16, device="npu")
# tile_level 参数控制 tile 大小:
# auto → runtime 根据 shape 和硬件自动选(默认)
# small → tile_size=32,适合 seq_len 短、batch 大的场景
# large → tile_size=128,适合 seq_len 长、head_dim 大的场景
O = flash_attention(Q, K, V, tile_level="auto")
# 一个常见判断逻辑
# seq_len >= 2048 + head_dim >= 64 → 设 large
# batch >= 8 + seq_len <= 1024 → 设 small
# 其余情况 → auto
一个常见的坑在这里:如果手动传了一个非对齐的 seq_len(比如 4097),tile 边界处理会变得复杂,内存管理器必须做 padding,working set 变大,缓存命中率下降。很多情况下 4096 跑出来比 4097 快,不是算法问题,是内存布局问题。可以在 trace 里看到 L1 miss rate 的差异。
第三层:任务调度,runtime 做的第二个关键决策
算子图有了,内存布局定了,接下来 runtime 要决定这些算子以什么顺序、在哪些计算核上执行。
这叫任务调度(Task Scheduling)。对 FlashAttention 来说,调度问题的核心是:三个子算子之间有数据依赖,但它们各自内部的多 tile 之间没有依赖。这意味着同一批 Q 数据上,第一个 MatMul 的 tile 0、tile 1、tile 2 之间可以并行,第二个 MatMul 的 tile 0 也已经可以在前者的 tile 0 完成后启动——只要数据流接得上。
CANN runtime 的调度器基于数据流图(Data Flow Graph)做依赖分析,生成一个拓扑序,然后在拓扑序上插入并行度:
MatMul_0 → Softmax_0 → MatMul_0'
MatMul_1 → Softmax_1 → MatMul_1'
MatMul_2 → Softmax_2 → MatMul_2'
三个 tile 并行跑,每条链内部有序,链之间通过 DMA 队列异步握手。如果调度器发现 L1 空间足够大,甚至会把 MatMul_n 和 Softmax_(n-1) 重叠执行——Tensor Core 算第 n 个 tile 的同时,Vector Core 跑第 n-1 个 tile 的 softmax,这是典型的算力与带宽流水线化。
ops-transformer 仓里的 FlashAttention 实现在这个调度层面做了定制:它把 tile_level 参数暴露出来,高级用户可以强制指定 tile 数量,控制并行度。默认值是 auto,runtime 根据输入 shape 和设备算力自动选。
# 不同 tile_level 对并行度的影响
# seq_len=4096, heads=32, head_dim=128, batch=4
# auto: runtime 自动推断,通常 tile_num=12(4×3)
# 每个 tile 处理 Q 的 4 个 block,K/V 的 3 个 block
O = flash_attention(Q, K, V, tile_level="auto")
# 设 tile_level="small": tile_num=24(4×6)
# 更多 tile,但每个 tile 的 working set 更小,L1 更友好
# 适合 batch 大、seq_len 长的场景
O = flash_attention(Q, K, V, tile_level="small")
# 设 tile_level="large": tile_num=6(2×3)
# 更少 tile,但每个 tile 的计算密度更高
# 适合 seq_len 很长、显存紧张、batch 小的场景
O = flash_attention(Q, K, V, tile_level="large")
但在某些场景(比如单核推理小 batch),手动设 tile=1 可以省掉额外的调度开销,反而更快。ops-transformer 内部会根据输入设备数和 batch 大小调整调度策略,用户不需要手动处理这个。
第四层:DMA 搬运,藏在调度背后的那只手
前面说了数据在 HBM 和片上之间来回搬运,但没说搬运是怎么发生的——这就是 DMA(Direct Memory Access)引擎做的事。
DMA 是 NPU 上的专用数据传输单元,不占用 Tensor Core 的算力。简单理解:Tensor Core 专注算,DMA 专注搬,两者互不阻塞。在 FlashAttention 的场景里,DMA 的调度策略直接决定了流水线能不能填满。
考虑这个时序:
t=0: DMA加载 Q[block0] → L1
t=1: Tensor Core 算 Q[block0] × K^T
t=2: Vector Core 算 softmax,同时 DMA 预取 K[block1]
t=3: Tensor Core 算 softmax_result × V,同时 DMA 加载 V[block1]
t=4: Tensor Core 写回 O[block0],同时启动下一轮的 DMA 加载
DMA 和计算重叠得越好,有效算力占比越高。如果 DMA 调度不及时,Tensor Core 算完当前 block 要等 DMA 加载下一个 block,流水线就断了。
CANN runtime 的 DMA 调度器会根据算子图的依赖关系预判下一块数据是什么时候需要的,提前下发 DMA 指令。这个预判基于 tile 的大小和当前硬件的 DMA 吞吐估算。这套机制叫异步流水线调度(Async Pipeline Scheduling),是 runtime 在幕后做的大量工作之一。
用 asc-tools 可以看到 DMA 等待时间的 trace:
# 跑一次 FlashAttention,执行完毕后查看每个阶段的耗时
$ python -m ascend_tools trace flash_attention_bench.py
# 输出示例(简化)
[MatMul(Q,K)] tensor_core: 1.23ms | dma_wait: 0.18ms
[Softmax ] vector_core: 0.56ms | dma_wait: 0.04ms
[MatMul(P,V)] tensor_core: 1.31ms | dma_wait: 0.21ms
# dma_wait 列就是 Tensor Core 等待 DMA 的时间
# 这个数字越小,说明 DMA 调度越及时,流水线越紧
# 如果 dma_wait > 0.5ms,说明 tile 大小设置不对,working set 不在 L1 里
ops-transformer 里的 FlashAttention 跑出 1.8× 的吞吐提升,DMA 调度优化贡献了相当比例——算子在片上跑得快是因为数据喂得及时,数据喂得及时是因为 DMA 调度器提前把活干完了。
第五层:kernel 启动开销,被低估的那块成本
前面都在讲算子执行本身,这里要提一个容易被忽略的环节:kernel 启动开销。
每一次算子执行,runtime 都要往调度器提交一个 task,调度器分配计算核、初始化 DMA、启动 Tensor Core。这个初始化过程不是免费的,大概在几微秒量级。对于大矩阵运算(比如 seq_len=4096 的 Attention),kernel 执行时间在毫秒级,几微秒的启动开销可以忽略。但对于 seq_len 较小(比如 256 或 512)的场景,kernel 执行时间只有几十微秒,启动开销可能占到 20%~30%。
Runtime 的优化手段是算子融合(Operator Fusion):把多个相邻的子算子合并成一个更大的 kernel,共享一次启动开销。FlashAttention 里的三个子算子如果能融合成一个大 kernel,kernel 启动从三次变成一次,这个比例就很可观了。
问题在于融合有条件:两个算子能融合的前提是它们之间没有同步点,而且融合后的 kernel 大小不超过硬件的寄存器文件和 L1 容量限制。FlashAttention 的三个子算子之间有数据依赖,Tensor Core 和 Vector Core 之间的结果传递需要同步点,这个同步点叫 barrier,它是融合的硬墙。
所以 CANN runtime 并不是把所有能融合的算子都融合了——有些边界条件决定了某些算子必须分开跑,即使融合在数学上可行。理解这一点,有助于判断手上的模型在昇腾 NPU 上有没有进一步优化的空间。
设计取舍:为什么 runtime 不做更多自动优化
一个自然的问题是:为什么 CANN runtime 不自动把 FlashAttention 的 tile 大小、DMA 调度、算子融合全部优化到最优?
答案是:信息的缺失。Runtime 不知道输入数据的语义分布——不知道 Q 里面哪些位置是 padding,哪些是真实 token。它也不知道这一轮 attention 之后,下一个 kernel 是什么,DMA 可以提前多久开始预取数据。Graph Compiler 在编译期做静态分析,但真实的运行时数据流只有到了执行那一刻才知道。
这跟 CPU/GPU 上的优化困境是一样的:编译器能做静态优化,但 runtime 的动态信息永远比编译期多。差距在于,CPU/GPU 的 runtime 优化(比如 CUDA stream 调度)已经高度成熟,而 CANN 的 runtime 生态还在快速演进——ops-transformer 仓里的实现在不断吸收来自社区的反馈,把更多运行时决策从手动调优迁移到自动优化。
这也是为什么 CANN 开源社区在推进 asc-tools 这类工具,让开发者能更直观地观察 runtime 做了什么、瓶颈在哪里。从"黑箱调用"到"可观测执行",是整个生态成熟度提升的信号。
结尾
调用 flash_attention() 的时候,你触发的是一整套经过精心编排的执行流水线:图编译器拆解算子,内存管理器规划 tile 布局,调度器分配并行度,DMA 引擎搬运数据,kernel 启动器初始化执行单元。每个环节都有优化空间,每个环节也都有它不得不这么做的约束。
昇腾 NPU 上的 FlashAttention 之所以能跑出那样的性能数字,不只是因为算法本身,更重要的是 CANN runtime 在底层把这些环节串成了一条高效的数据流。ops-transformer 仓里的实现只是这条链路的暴露端——真正发力的地方,藏在那些你看不见的调度决策里。
runtime 的优化是一个持续演进的方向。关注 CANN 的更新节奏,每次版本发布通常都会带来调度器和内存管理的改进——这些改进会直接体现在 ops-transformer 的 benchmark 里,不需要改一行代码。
更多推荐



所有评论(0)