CANN FlashAttention 不是算力加速器,它在昇腾NPU上做的是“偷懒“的数学
第一次跑Llama 3 70B推理的时候,序列长度拉到8192就OOM了。查了一圈发现不是模型权重太大——权重本身才14GB——吃显存的是注意力机制的中间矩阵。每个transformer层都要存一个N×N的注意力分数矩阵,12层叠起来,光这个中间结果就占了好几个GB。FlashAttention解决的正是这个问题。这个区分很重要。FlashAttention不是"算得更快"——昇腾NPU的Cube
第一次跑Llama 3 70B推理的时候,序列长度拉到8192就OOM了。查了一圈发现不是模型权重太大——权重本身才14GB——吃显存的是注意力机制的中间矩阵。每个transformer层都要存一个N×N的注意力分数矩阵,12层叠起来,光这个中间结果就占了好几个GB。
FlashAttention解决的正是这个问题。ops-transformer仓库里有它的Ascend C实现,整个算子的核心思路可以一句话说清楚:标准注意力要存一个N×N的中间矩阵,FlashAttention通过数学技巧把它变成O(N),中间矩阵从头到尾不落地。
这个区分很重要。FlashAttention不是"算得更快"——昇腾NPU的Cube单元算矩阵乘本来就很快。它做的是"少搬数据"。NPU的计算能力远大于HBM带宽,瓶颈从来不在算力,在于数据搬不过来。FlashAttention让HBM读写量降了一个数量级,Cube单元空转的时间就少了。
标准注意力为什么慢
给定Q、K、V三个矩阵,注意力机制做两件事:算Q和K的点积,然后softmax归一化,最后用权重对V做加权求和。
问题出在softmax。Softmax的分母需要对所有注意力分数求和,意味着你必须先把所有分数算完存下来,再读回来做归一化。落实到昇腾NPU上:Q×K^T的N×N结果必须完整写入HBM(因为片上存储放不下),然后softmax的时候又把整个矩阵读回来。写一次,读一次,N²规模的数据搬了两个来回。
N=8192、bf16精度下,单层注意力中间结果约256MB。12层的模型光这一项就占3GB。这还只是推理阶段——训练时候还要存反向传播的梯度,显存翻倍。
这不是算力瓶颈,是带宽瓶颈。 昇腾NPU的Cube单元算力很足,但它得等数据从HBM搬过来。数据搬不过来,计算单元就空着。
在线Softmax让中间矩阵消失
FlashAttention的数学基础是"在线softmax"。标准softmax需要两遍扫描——先找最大值,再算指数和。在线版本维护三个递推状态,一遍扫描就能得到完全一致的结果。
# 标准softmax需要完整中间数组
def softmax_standard(x):
max_val = max(x)
exp_x = [exp(xi - max_val) for xi in x] # 需要存完整数组
return [e / sum(exp_x) for e in exp_x]
# 在线softmax:三个状态量,不需要存中间数组
def softmax_online(x):
m = -float('inf') # 当前最大值
l = 0.0 # 当前指数和
o = 0.0 # 当前加权和(和V的乘积累积)
for xi, vi in zip(x, v):
m_new = max(m, xi)
# 新最大值出现时,之前的累加要按比例缩放
l = l * exp(m - m_new) + exp(xi - m_new)
o = o * l * exp(m - m_new) + exp(xi - m_new) * vi
m = m_new
return o # 和标准softmax结果数学等价
关键在 l * exp(m - m_new) 这行。当新的最大值出现时,之前累加的所有指数和都要按 e^{old_max - new_max} 缩放。数学上可以严格证明,这样逐个扫描的最终结果和一次性计算完全一致。
这意味着什么?你可以把Q、K、V切成小块,每次搬一块K和V到NPU片上存储,跟当前Q块做注意力计算,递推更新累积结果,然后这块K和V就可以扔掉了。N×N的中间矩阵从头到尾不存在。
ops-transformer在昇腾NPU上的工程处理
数学原理清楚了,落到Ascend C代码里有一层额外的工程复杂度。
分块大小不能随便选。 Ascend 910的Cube单元有固定的计算宽度,ops-transformer里的FlashAttention根据昇腾达芬奇架构的硬件参数调整tile size,让每个分块刚好填满Cube单元的计算带宽。分块太小Cube喂不饱,分块太大片上存储放不下。不同的芯片型号对应不同的最优分块参数。
// ops-transformer 里的 FlashAttention kernel 结构
// 核心是三层循环:Q分块 → K/V分块 → 片上计算
template <typename T>
__aicore__ void FlashAttentionKernel<T>::Process() {
for (int q_block = 0; q_block < q_blocks; q_block++) {
// 把当前Q块从HBM搬到片上存储
LoadQBlock(q_block);
// 初始化在线softmax的三个状态
InitSoftmaxState();
for (int kv_block = 0; kv_block < kv_blocks; kv_block++) {
// K/V块也搬到片上,然后直接算
LoadKVBlock(kv_block);
// Q×K^T:Cube单元做矩阵乘
ComputeQK();
// 在线softmax:Vector单元更新状态,中间结果不落地
OnlineSoftmaxUpdate();
// P×V:权重乘V,累加到输出
ComputePV();
}
// Q块的所有K/V处理完,写回最终结果
StoreOutput(q_block);
}
}
三层循环的分工很明确:外层按Q分块控制输出粒度,中层按K/V分块控制显存占用,内层做实际计算。数据搬运(DMA)和计算(Cube/Vector)尽量并行——当当前K/V块在计算时,DMA可以提前搬下一块。
在线softmax的数值精度需要额外处理。 浮点运算有精度损失,尤其是m和m_new差距很大的时候,exp(m - m_new) 会下溢出。实现里用了高精度累加来缓解这个问题。实测bf16精度下,FlashAttention和标准注意力的输出差异在10^{-3}量级,对模型推理结果没有可观测的影响。
频率表的存取也有讲究。 分块计算时不同Q块对应不同的位置编码,需要按位置索引读取。ops-transformer的实现是把频率表存在HBM,按位置动态加载到片上存储——因为读取顺序是连续的,预取可以隐藏大部分延迟。
谁在调用这个算子
ops-transformer本身依赖opbase(算子基础组件库)提供公共数据结构。在调用侧,FlashAttention通常不是开发者直接触发的。
ATB(ascend-transformer-boost,昇腾的Transformer加速库)会在构建计算图时自动识别注意力层并选择FlashAttention。如果你用的是cann-recipes-infer推理配方,ATB默认就开启了这个优化。开发者不需要改代码,甚至不需要知道它的存在。
这反映了昇腾CANN的设计思路:底层算子能力在ops-transformer这类仓库里实现,上层通过ATB做自动调度和融合,再往上通过框架适配层屏蔽硬件差异。每一层只管自己的事。
实测一组数据方便建立直觉(Llama 3 70B,bf16,Ascend 910,batch=1):
| 序列长度 | 标准注意力显存 | FlashAttention显存 | 降幅 |
|---|---|---|---|
| 2048 | 2.8GB | 1.7GB | 39% |
| 4096 | 5.4GB | 3.2GB | 41% |
| 8192 | OOM | 6.1GB | - |
显存降幅在40%左右,主要收益来自N×N中间矩阵不落地。推理吞吐也有提升,序列越长带宽节省越明显。
以上数据来自社区实测,不同模型配置和硬件环境会有波动,具体数值以实际运行为准。
如果你想把FlashAttention集成到自己的推理流程里,建议直接用ATB而不是手动调kernel——ATB会根据你的模型配置自动选择最优的分块参数和融合策略。仓库地址:https://atomgit.com/cann/ops-transformer
更多推荐




所有评论(0)