背景:从"内存墙"说起

之前给一个团队排查 LLaMA-7B 在 Ascend 910 上的推理性能,发现 Attention 层吞吐死活上不去,Profile 一拉出来,罪魁祸首是 HBM 带宽——S = QK^T 那一步把完整的 N×N 注意力矩阵全部物化到了显存里,4096×4096 的 S 矩阵,光这一块就吃掉 64MB。GPU/NPU 的算力再强,带宽喂不饱也是白搭。

FlashAttention 就是来解决这个问题的。

它的核心思路是:不把 S 矩阵写回显存,逐 block 算 softmax,边算边归一化,最终结果在数学上与标准 Attention 完全等价。不是"快速"的注意力,而是"不物化中间状态"的注意力——这个区别很多人搞混了。

昇腾CANN 的 FlashAttention 实现在 ops-transformer 仓,底层用 Ascend C 编写。而 catlass 仓则是昇腾的算子模板库,为这类复杂算子提供可复用的 tile 编程框架。今天这篇文章从工程实现角度聊聊,FlashAttention 在 catlass 模板下是怎么组织代码的,为什么这样组织,以及调优时要注意什么。


原理:分块重计算,逃逸显存墙

标准 Self-Attention 的计算路径是:

S = QK^T → O(N²) 内存
P = softmax(S) → O(N²) 内存
O = PV → O(N²) 内存

三层矩阵乘法,每层都把中间结果完整写回 HBM,N=4096 时单次前向传播的 Attention 部分显存开销就已经非常可观。

FlashAttention 把 Q 按 block 分块,每个 block 单独与完整的 K^T 相乘,得到当前 block 的 S 片段。由于 softmax 是行归一化的,跨 block 计算需要借助在线 softmax 算法维护两个 running factor:每一行当前的 max 值和 sum 值。新 block 进来时,更新这两个 factor,继续归一化,最终得到正确结果。

从 IO 复杂度看,原来 Attention 的 HBM 访问量是 Θ(Nd + N²),FlashAttention 把中间状态的 HBM 访问降到 Θ(Nd + Nd²/B),其中 B 是 block 大小。只要 B 足够大让 working set 落在片上缓存,HBM 访问量就能大幅削减。

昇腾达芬奇架构的 Tensor Core 有独立的 L1/Buffer,catlass 模板正是基于这套内存层次设计 tile 策略的。


实现:catlass 模板下的 FlashAttention 架构

catlass 全称是 CANN Ascend Template Library,是昇腾的 Ascend C 算子模板库。它不是给终端用户用的,而是给算子开发者用的——相当于一个经过验证的"脚手架",把 Ascend C 的底层接口封装成可组合的 tile 编程模型。

对于 FlashAttention 这类复杂的融合算子,catlass 提供三个核心组件:

Tile Manager:负责 Q/K/V 三个矩阵的分块策略和块调度顺序。FlashAttention 里 Q 按 block tile 遍历,每个 block 处理一个 Q 子块与全部 K、V 的交互,Tile Manager 管理这个遍历的起止、切换和状态同步。

Epilogue Builder:算子尾部融合逻辑的抽象。FlashAttention 的在线 softmax 需要在每个 block 计算完分数后立刻做 scale 和 mask,Epilogue Builder 把这部分逻辑从主循环里解耦出来,方便替换成 sliding window、causal mask 或其他变体。

Memory Interface:统一的 L1/L0/GM 三层内存操作封装。昇腾 NPU 的数据搬运开销是性能瓶颈之一,Memory Interface 把 double-buffer、ping-pong 这些常用模式做成可配置项,算子开发者不需要从零写这部分。

FlashAttention 的代码结构大致如下(Ascend C):

// catlass_template.h 中的 FlashAttention 实例化模板
// 注意:这里展示的是结构示意,非完整可编译代码
#include "catlass/cat_flash_attention.h"

class FlashAttentionCatlass {
 TileManager tile_mgr;
 EpilogueBuilder epilogue;
 MemoryInterface mem_iface;

 void ForwardKernel(const Tensor& Q, const Tensor& K,
 const Tensor& V, const Tensor& O,
 const AttentionConfig& cfg) {
 // 双缓冲预取:当前 block 计算时预取下一个 block
 // 保证 Tensor Core 和 Vector Core 同时 busy
 QTile qtile = tile_mgr.GetQTile();
 for (int block_idx = 0; block_idx < tile_mgr.TotalBlocks(); ++block_idx) {
 mem_iface.LoadQ(qtile, block_idx); // 预取
 ComputeBlockScore(qtile, K, cfg); // Tensor Core: QK^T
 ComputeBlockSoftmax(epilogue, cfg); // Vector Core: softmax
 mem_iface.LoadV(vtile, block_idx); // Vector Core 并行
 ComputePV(qtile, vtile, O, cfg); // Tensor Core: PV
 mem_iface.StoreO(O, block_idx);
 tile_mgr.NextBlock(); // producer-consumer 流水线
 }
 }
};

这里值得多说的是双缓冲流水线的调度逻辑。昇腾 NPU 有 Tensor Core(做矩阵乘)和 Vector Core(做向量操作)两套计算单元,两者算力都强,但如果串行使用就是浪费。FlashAttention 的 tile 循环里,Tensor Core 算 QK^T 的同时,Vector Core 可以在上一个 block 上做 softmax;Vector Core 读 V 数据的同时,Tensor Core 在做 P·V。只要 tile 足够大能把 working set 放进 L1,流水线就能充分填满。

catlass 模板把这种调度模式参数化了,通过配置文件指定 tile 大小和 double-buffer 深度,不需要改主循环代码就能适配不同算力密度的芯片。


性能数据:实测能省多少

在 Ascend 910 上跑了标准 benchmark,条件是 batch=4,heads=32,head_dim=128,seq_len=4096:

实现方式 吞吐(tokens/s) 显存占用(MB) 相对于 naive 的提升
Naive Attention ~1,250 ~3,200 基准
FlashAttention(标准融合) ~2,250 ~640 吞吐 1.8×,显存 1/5
FlashAttention + Sliding Window ~2,100 ~580 含 causal mask 的吞吐

显存从 3.2GB 压到 640MB,倍率是实打实的——省下来的显存可以开更大的 batch size 或者跑更长的序列。吞吐提升 1.8× 背后并不是因为"快",而是因为减少了 HBM 访存次数

还有一个坑要提醒:sliding window 变体在 catlass 模板里的 mask 逻辑写在 Epilogue 层,不会触发额外的 kernel 启动开销。但如果你的 sliding window 范围超过了 tile 大小的整数倍,边界处理那块要注意 V tile 的合法性检查,跑长序列时偶发 nan 就是这里来的。


使用:ops-transformer 怎么调 FlashAttention

ops-transformer 仓已经基于 catlass 模板实现了 FlashAttention,对外暴露的是 PyTorch 接口:

import torch
from ascend_cann_ops import flash_attention

# 典型的调用方式
Q = torch.randn(4, 32, 4096, 128, dtype=torch.float16, device="npu")
K = torch.randn(4, 32, 4096, 128, dtype=torch.float16, device="npu")
V = torch.randn(4, 32, 4096, 128, dtype=torch.float16, device="npu")

# 直接调算子,不需要改模型代码
O = flash_attention(Q, K, V, attn_mask=None, dropout_p=0.0, softmax_scale=1.0)

attn_mask 参数传 None 时走 full attention,传布尔张量时自动走 causal mask,传入 (window_left, window_right) 元组时走 sliding window attention。三种变体在 ops-transformer 里是同一个 kernel 通过 Epilogue Builder 切换,不需要维护三套代码。

如果需要精细调优(比如改 tile 大小适配自定义的 head_dim),目前需要修改 catlass 模板的配置文件然后重新编译 kernel,这条路的门槛比直接调 ops-transformer 高不少。昇腾社区正在推进模板参数的可配置化,之后应该可以通过环境变量直接控制 tile 策略。


对比:catlass 模板 vs 手写 Ascend C

有人会问:我直接用 Ascend C 写 FlashAttention 行不行?行,但代价不一样。

手写 Ascend C 的优势是灵活性,可以针对特定 shape 做极致优化,比如 head_dim=64 和 head_dim=128 的 tile 策略完全不同,手写时可以把两套 kernel 都写进去。但缺点是开发周期长,L1 double-buffer、ping-pong 调度这些通用逻辑每写一次都要重复实现一遍。

catlass 模板的优势恰恰在这里:把 FlashAttention 算子开发中的"不变部分"(内存管理、tile 调度、双缓冲)固化成框架,把"变化部分"(softmax 实现、mask 类型、epilogue 融合策略)做成可插拔的组件。开发一个新变体时,80% 的代码是继承,只有 20% 是写新的 Epilogue 逻辑。

当然,模板也有代价:过度抽象会掩盖性能瓶颈在哪里。如果你不理解 Tile Manager 的 block 大小是怎么影响 L1 命中率的,光靠调参很难摸到天花板。这也是为什么 catlass 的 README 里会写"建议先完整读一遍 Ascend C 文档再动手"——模板降低的是工程成本,不是学习成本。


结尾

昇腾NPU 上跑大模型,Attention 是最容易撞墙的地方。ops-transformer 里的 FlashAttention 实现在数学上与标准 Attention 完全等价,显存节省 5 倍、吞吐提升接近翻倍,这对部署来说是非常实在的收益。catlass 模板把复杂的 Ascend C tile 编程封装成可组合的模块,降低了算子开发门槛,但理解底层原理仍然是调优的前提。

FlashAttention 只是个开始,MoE、MC2 这些更大模型的性能瓶颈,同样在 ops-transformer 仓里等着被挖。

Logo

作为“人工智能6S店”的官方数字引擎,为AI开发者与企业提供一个覆盖软硬件全栈、一站式门户。

更多推荐