前言

写这篇文章的时候,我刚帮一个学弟 debug 了一个奇怪的问题——同样的 LLaMA 模型,在 GPU 上推理延迟 800ms,搬到昇腾 NPU 上反而变成了 2.3 秒。他以为是昇腾硬件不行,差点换回 GPU。

问题出在哪?他没有用 FlashAttention 融合算子,直接跑的 PyTorch 原生实现。昇腾 NPU 的算力其实比同价位的 GPU 强不少,但标准 Attention 的数据搬运模式在 NPU 上吃不开——片上缓存(Ub)的优势完全没发挥出来。

加上 ops-transformer 的 FlashAttention 之后,延迟从 2.3 秒降到了 650ms。比 GPU 还快。

先搞清楚一个问题:标准 Attention 为什么慢?

标准 Attention 的计算过程,拆开来看就是三步:

第①步:Q 乘 K → 得到一个分数矩阵
第②步:对分数矩阵做 Softmax → 变成概率权重
第③步:权重乘 V → 得到最终输出

听起来很简单对吧?问题出在搬运数据上

打个比方:你在客厅(NPU 片上缓存,容量小但速度极快)做数学题,但草稿纸放在卧室(HBM 高带宽内存,容量大但取一趟要时间)。

标准 Attention 的做法是:

  1. 在客厅算 QK^T → 结果写回卧室
  2. 从卧室拿出来做 Softmax → 结果再写回卧室
  3. 从卧室再拿出来乘 V → 最终结果写回卧室

每一"步"都要跑一趟卧室。 对短序列来说还好,一旦序列长度到 4096、8192,那趟路跑得次数多了,时间全花在路上了——计算根本没成为瓶颈,搬运才是。

这就是 FlashAttention 要解决的问题。


FlashAttention 的核心思路:别往卧室跑了

FlashAttention 的想法粗暴而有效:算完不放回卧室,直接留在客厅继续算。

具体来说:

  1. QK^T 算完 → 结果留在客厅的 Ub(昇腾达芬奇架构的片上缓存)
  2. Softmax 直接在客厅做 → 结果还留在客厅
  3. 权重乘 V → 也在客厅完成 → 只把最终结果写回卧室

一趟卧室都不用跑(中间结果不需要),搬运开销直接砍掉 2/3。

这就是 FlashAttention 在昇腾 NPU 上快 3 倍的根本原因:不是算得更快了,是搬得更少了。


ops-transformer 里的 FlashAttention 长什么样?

在 CANN 的生态里,FlashAttention 的实现藏在 ops-transformer 仓库里——这是 CANN 第 2 层(算子服务层)的 Transformer 类算子库,专门做大模型相关的进阶算子。

ops-transformer 不是"一个算子",是一堆算子的集合:FlashAttention、MoE Router、MC2 等等。FlashAttention 是其中使用频率最高的一个。

从代码层面看,ops-transformer 的 FlashAttention 跟学术界的原始论文有些不一样:

① Tiling 策略是手动调过的

论文里的 FlashAttention 用的是通用的分块大小,但昇腾达芬奇架构的 Ub 大小是固定的(64KB per Core),所以 ops-transformer 里的 tiling 参数是针对这个硬件手动调优过的。

// 这里的 TILE_SIZE 不是拍脑袋选的
// 是按 Ub 容量 / 数据精度算出来的最优值
// 对 FP16 来说,128x128 是 Ascend 910 上的甜点位
constexpr int TILE_SIZE = 128;

② Softmax 用了在线算法

标准 Softmax 需要两遍扫描:先扫一遍找最大值,再扫一遍算 exp 求和。但 FlashAttention 要求所有东西都在片上缓存完成,不能回 HBM。所以 ops-transformer 用了在线 Softmax 算法——一遍扫描就搞定,中间结果存在寄存器里,不需要回头。

③ 因果 Mask(Causal Mask)直接编码进计算

大模型推理用的都是自回归模式,attention 只能看到当前 token 之前的 token。传统做法是先算完所有 attention 分数,再用一个 mask 矩阵把"未来"的部分置零。ops-transformer 的做法更直接——在分块计算的时候,直接跳过那些不需要算的位置,连算都不算,白省计算量。


和 ascend-transformer-boost 的关系

这里容易搞混一个事:ops-transformer 和 ascend-transformer-boost(ATB)是什么关系?

简单说:

  • ops-transformer 是算子实现层——底层的 Ascend C 代码写在这里,控制每一个计算细节
  • ascend-transformer-boost(ATB) 是上层加速库——把 ops-transformer 里的算子封装成易用的 Python/C++ API,同时做多算子融合调度

打个比方:ops-transformer 是厨房里的菜谱和食材,ATB 是外卖平台——你点一份"FlashAttention 套餐",ATB 帮你协调 ops-transformer 里的各个算子,自动完成融合和调度。

所以你在用 ATB 的 flash_attention() 接口时,底层调用的就是 ops-transformer 里的算子实现。


不同版本里的变化

ops-transformer 的 FlashAttention 跟着 CANN 版本一起迭代:

CANN 版本 FlashAttention 变化
8.0 之前 基础 FlashAttention,支持 FP16,固定 tiling
CANN 8.0 加入 MoE 融合、GQA(分组查询注意力)适配
CANN 8.5 支持 BF16 精度、滑动窗口 attention、更长序列优化
全面开源后 完整代码可在 AtomGit 查看,社区可贡献

如果你在用 CANN 8.0 之前的版本,建议升级——8.0 之后的 FlashAttention 算子融合策略做了很大改进,长序列场景下性能提升明显。


一张图看懂整个流程

用户调用 ATB / PyTorch
 ↓
ascend-transformer-boost(加速库,负责调度)
 ↓
ops-transformer(算子实现层)
 ↓
opbase(通用算子基础组件)
 ↓
Ascend C → 昇腾 NPU(达芬奇架构,Ub 片上缓存)

从用户的一行 Python 代码,到 NPU 上的实际计算,中间经过了四层。ops-transformer 处在第三层,是"把算术逻辑翻译成硬件能执行的操作"那个环节。


总结:一句话说就是

FlashAttention 的本质不是"更快地算",而是"更少地搬"——把中间计算结果留在 NPU 的片上缓存(Ub)里,省掉反复读写 HBM 的开销。

ops-transformer 是这个算子在昇腾 CANN 里的底层实现,针对达芬奇架构做了 tiling、在线 Softmax、因果 mask 等硬件级优化。ATB 在它上面做了一层易用封装,让开发者不用关心底层细节就能直接用。

三件事记住了就够:别往 HBM 搬东西、tiling 要针对硬件调、算子融合让 ATB 来调度。

Logo

作为“人工智能6S店”的官方数字引擎,为AI开发者与企业提供一个覆盖软硬件全栈、一站式门户。

更多推荐