第一次跑Llama 3 70B推理的时候,序列长度拉到8192就OOM了。查了一圈发现不是模型权重太大——权重本身才14GB——吃显存的是注意力机制的中间矩阵。每个transformer层都要存一个N×N的注意力分数矩阵,12层叠起来,光这个中间结果就占了好几个GB。

FlashAttention解决的正是这个问题。ops-transformer仓库里有它的Ascend C实现,整个算子的核心思路可以一句话说清楚:标准注意力要存一个N×N的中间矩阵,FlashAttention通过数学技巧把它变成O(N),中间矩阵从头到尾不落地。

这个区分很重要。FlashAttention不是"算得更快"——昇腾NPU的Cube单元算矩阵乘本来就很快。它做的是"少搬数据"。NPU的计算能力远大于HBM带宽,瓶颈从来不在算力,在于数据搬不过来。FlashAttention让HBM读写量降了一个数量级,Cube单元空转的时间就少了。

标准注意力为什么慢

给定Q、K、V三个矩阵,注意力机制做两件事:算Q和K的点积,然后softmax归一化,最后用权重对V做加权求和。

问题出在softmax。Softmax的分母需要对所有注意力分数求和,意味着你必须先把所有分数算完存下来,再读回来做归一化。落实到昇腾NPU上:Q×K^T的N×N结果必须完整写入HBM(因为片上存储放不下),然后softmax的时候又把整个矩阵读回来。写一次,读一次,N²规模的数据搬了两个来回。

N=8192、bf16精度下,单层注意力中间结果约256MB。12层的模型光这一项就占3GB。这还只是推理阶段——训练时候还要存反向传播的梯度,显存翻倍。

这不是算力瓶颈,是带宽瓶颈。 昇腾NPU的Cube单元算力很足,但它得等数据从HBM搬过来。数据搬不过来,计算单元就空着。

在线Softmax让中间矩阵消失

FlashAttention的数学基础是"在线softmax"。标准softmax需要两遍扫描——先找最大值,再算指数和。在线版本维护三个递推状态,一遍扫描就能得到完全一致的结果。

# 标准softmax需要完整中间数组
def softmax_standard(x):
 max_val = max(x)
 exp_x = [exp(xi - max_val) for xi in x] # 需要存完整数组
 return [e / sum(exp_x) for e in exp_x]

# 在线softmax:三个状态量,不需要存中间数组
def softmax_online(x):
 m = -float('inf') # 当前最大值
 l = 0.0 # 当前指数和
 o = 0.0 # 当前加权和(和V的乘积累积)

 for xi, vi in zip(x, v):
 m_new = max(m, xi)
 # 新最大值出现时,之前的累加要按比例缩放
 l = l * exp(m - m_new) + exp(xi - m_new)
 o = o * l * exp(m - m_new) + exp(xi - m_new) * vi
 m = m_new

 return o # 和标准softmax结果数学等价

关键在 l * exp(m - m_new) 这行。当新的最大值出现时,之前累加的所有指数和都要按 e^{old_max - new_max} 缩放。数学上可以严格证明,这样逐个扫描的最终结果和一次性计算完全一致。

这意味着什么?你可以把Q、K、V切成小块,每次搬一块K和V到NPU片上存储,跟当前Q块做注意力计算,递推更新累积结果,然后这块K和V就可以扔掉了。N×N的中间矩阵从头到尾不存在。

ops-transformer在昇腾NPU上的工程处理

数学原理清楚了,落到Ascend C代码里有一层额外的工程复杂度。

分块大小不能随便选。 Ascend 910的Cube单元有固定的计算宽度,ops-transformer里的FlashAttention根据昇腾达芬奇架构的硬件参数调整tile size,让每个分块刚好填满Cube单元的计算带宽。分块太小Cube喂不饱,分块太大片上存储放不下。不同的芯片型号对应不同的最优分块参数。

// ops-transformer 里的 FlashAttention kernel 结构
// 核心是三层循环:Q分块 → K/V分块 → 片上计算

template <typename T>
__aicore__ void FlashAttentionKernel<T>::Process() {
 for (int q_block = 0; q_block < q_blocks; q_block++) {
 // 把当前Q块从HBM搬到片上存储
 LoadQBlock(q_block);

 // 初始化在线softmax的三个状态
 InitSoftmaxState();

 for (int kv_block = 0; kv_block < kv_blocks; kv_block++) {
 // K/V块也搬到片上,然后直接算
 LoadKVBlock(kv_block);

 // Q×K^T:Cube单元做矩阵乘
 ComputeQK();
 // 在线softmax:Vector单元更新状态,中间结果不落地
 OnlineSoftmaxUpdate();
 // P×V:权重乘V,累加到输出
 ComputePV();
 }

 // Q块的所有K/V处理完,写回最终结果
 StoreOutput(q_block);
 }
}

三层循环的分工很明确:外层按Q分块控制输出粒度,中层按K/V分块控制显存占用,内层做实际计算。数据搬运(DMA)和计算(Cube/Vector)尽量并行——当当前K/V块在计算时,DMA可以提前搬下一块。

在线softmax的数值精度需要额外处理。 浮点运算有精度损失,尤其是m和m_new差距很大的时候,exp(m - m_new) 会下溢出。实现里用了高精度累加来缓解这个问题。实测bf16精度下,FlashAttention和标准注意力的输出差异在10^{-3}量级,对模型推理结果没有可观测的影响。

频率表的存取也有讲究。 分块计算时不同Q块对应不同的位置编码,需要按位置索引读取。ops-transformer的实现是把频率表存在HBM,按位置动态加载到片上存储——因为读取顺序是连续的,预取可以隐藏大部分延迟。

谁在调用这个算子

ops-transformer本身依赖opbase(算子基础组件库)提供公共数据结构。在调用侧,FlashAttention通常不是开发者直接触发的。

ATB(ascend-transformer-boost,昇腾的Transformer加速库)会在构建计算图时自动识别注意力层并选择FlashAttention。如果你用的是cann-recipes-infer推理配方,ATB默认就开启了这个优化。开发者不需要改代码,甚至不需要知道它的存在。

这反映了昇腾CANN的设计思路:底层算子能力在ops-transformer这类仓库里实现,上层通过ATB做自动调度和融合,再往上通过框架适配层屏蔽硬件差异。每一层只管自己的事。

实测一组数据方便建立直觉(Llama 3 70B,bf16,Ascend 910,batch=1):

序列长度 标准注意力显存 FlashAttention显存 降幅
2048 2.8GB 1.7GB 39%
4096 5.4GB 3.2GB 41%
8192 OOM 6.1GB -

显存降幅在40%左右,主要收益来自N×N中间矩阵不落地。推理吞吐也有提升,序列越长带宽节省越明显。

以上数据来自社区实测,不同模型配置和硬件环境会有波动,具体数值以实际运行为准。

如果你想把FlashAttention集成到自己的推理流程里,建议直接用ATB而不是手动调kernel——ATB会根据你的模型配置自动选择最优的分块参数和融合策略。仓库地址:https://atomgit.com/cann/ops-transformer

Logo

作为“人工智能6S店”的官方数字引擎,为AI开发者与企业提供一个覆盖软硬件全栈、一站式门户。

更多推荐