为什么 Attention 是瓶颈?

先回顾一下问题本身。标准 Self-Attention 的计算过程:

Q, K, V = Linear(x)         # 投影
S = Q @ K^T                  # 注意力分数
P = Softmax(S)               # 归一化
O = P @ V                    # 加权求和

看起来就四步,但问题出在显存访问上。Q、K、V 的 shape 是 [batch, heads, seq_len, dim],当 seq_len 到 8192 甚至更长的时候,中间矩阵 S 的 shape 是 [batch, heads, seq_len, seq_len],这个矩阵大得离谱。以 LLaMA 13B 为例,32 个注意力头,seq_len=8192,S 矩阵光是 FP16 就要占 32GB 显存,根本放不下。

而且这个 S 矩阵算完 Softmax 之后还要跟 V 做矩阵乘法,意味着要再读一遍。来回读写 HBM(显存)的带宽就成了瓶颈。

FlashAttention 的核心思路:不分步计算,把 Attention 整个流程放在片上 SRAM 里完成,避免中间结果写回 HBM。

听起来简单,做起来要处理两个问题:Softmax 的在线计算(因为不知道全局最大值没法直接算 Softmax)和分块策略(SRAM 容量有限,得分块处理)。

标准实现 vs IO-Aware 实现

先看标准实现的问题在哪。

标准实现(Naive Attention):

import torch
import torch.nn.functional as F

def naive_attention(query, key, value):
    """标准 Self-Attention,中间结果全部落回 HBM"""
    d_k = query.size(-1)
    # Q @ K^T 产生 [batch, heads, seq_len, seq_len] 的巨大矩阵
    scores = torch.matmul(query, key.transpose(-2, -1)) / (d_k ** 0.5)
    # Softmax 结果也要写回 HBM
    p_attn = F.softmax(scores, dim=-1)
    # 再读一遍 p_attn,跟 V 做矩阵乘
    return torch.matmul(p_attn, value)

4 次 HBM 读 + 4 次 HBM 写,中间矩阵 S 和 P 都要落回显存。seq_len 一大,HBM 带宽直接被撑爆。

FlashAttention 实现(基于 ops-transformer 的调用方式):

import torch
import torch_npu
from ops_transformer import flash_attention

def flash_attention_inference(query, key, value, seq_len, head_dim):
    """调用 ops-transformer 的 FlashAttention 算子
    分块计算,中间 Softmax 结果留在 UB(片上 SRAM),不写回 HBM"""
    # query/key/value: [batch, num_heads, seq_len, head_dim]
    attn_output = flash_attention.flash_attention_score(
        query,
        key,
        value,
        drop_mask=None,
        padding_mask=None,
        attn_head_num=query.shape[1],
        attn_dim_per_head=head_dim,
        scale_value=1.0 / (head_dim ** 0.5),
        input_layout="BSND",   # batch-seq-head-dim 排布
        seed=0,
        pre_tokens=seq_len,
        next_tokens=0,
        keep_prob=1.0,       # 推理不 dropout
    )
    return attn_output

HBM 读写次数大幅减少。代价是计算量略增(Softmax 的在线修正需要额外计算),但在现代硬件上计算远比显存访问快,所以总体是赚的。

昇腾 NPU 上的关键差异

到这一步,算法思路是一样的,NVIDIA 和昇腾都这么干。但落到具体实现上,昇腾 NPU 有几个关键差异:

差异一:SRAM 结构不同

NVIDIA GPU 的 SRAM 是 shared memory,一个 thread block 内的线程共享,大小通常 48KB-164KB。昇腾达芬奇架构的 SRAM 叫 Unified Buffer(UB),每个 AI Core 独享,大小是 1.5MB。

UB 比 shared memory 大很多,这意味着分块策略可以不一样。NVIDIA 那边每个 block 处理的 tile 更小,需要更细粒度的分块;昇腾这边 tile 可以更大,减少循环次数。

但 UB 的带宽分配也有讲究。达芬奇架构里,UB 同时要服务于向量计算单元和矩阵计算单元(Cube Unit),如果 FlashAttention 里 Softmax 的向量计算和 QK^T 的矩阵计算争抢 UB 带宽,性能就会打折扣。ops-transformer 里的实现做了一些调度上的优化,尽量让矩阵计算和向量计算流水线化,减少等待。

差异二:矩阵计算单元的指令不同

NVIDIA 的矩阵乘用的是 Tensor Core,通过 WMMA 指令触发。昇腾的矩阵计算单元叫 Cube Unit,通过专门的矩阵乘指令触发。两者的数据排布要求不同:

  • Tensor Core 要求数据按 128x128 的分块排布(FP16 场景下)
  • Cube Unit 要求数据按 16x16 的分块排布(FP16 场景下)

这意味着 Q、K、V 在进入矩阵乘之前要做数据重排(layout transform)。这个重排本身也要消耗算力和带宽,如果做得不精细,重排的开销可能抵消掉 FlashAttention 带来的收益。ops-transformer 里的实现在数据加载阶段就做了 prefetch 和 layout 转换,尽量把这个开销隐藏在计算流水线里。

差异三:Softmax 的在线实现细节

FlashAttention 的核心难点是 Softmax 的在线计算。标准 Softmax 需要先扫一遍求全局最大值(防止数值溢出),再扫一遍算 exp 和归一化。但分块计算的时候,你不知道后面块的最大值是多少,所以需要一种增量更新机制。

NVIDIA 的实现用的是 FlashAttention 论文里的 online softmax 方案,每次处理新块时用当前最大值修正之前的累加结果。昇腾上的实现在算法层面是一样的,但利用了达芬奇架构的向量计算单元做一些并行化的规约操作(reduce),比 GPU 上逐元素串行修正要快。

具体来说,online softmax 的核心逻辑是这样的:

import torch

def online_softmax_update(prev_max, prev_sum, prev_out, cur_scores, cur_values):
    """FlashAttention 中 Softmax 的增量更新逻辑
    每处理一个新的 KV 块,用新块的最大值修正之前的累加结果"""
    # 当前块的最大值
    cur_max = cur_scores.max(dim=-1, keepdim=True).values
    # 全局最大值更新
    new_max = torch.maximum(prev_max, cur_max)
    # 修正之前的累加结果(因为分母变了)
    correction = torch.exp(prev_max - new_max)
    prev_sum_corrected = prev_sum * correction
    prev_out_corrected = prev_out * correction
    # 当前块用新最大值做 Softmax
    cur_weights = torch.exp(cur_scores - new_max)
    cur_sum = cur_weights.sum(dim=-1, keepdim=True)
    cur_out = torch.matmul(cur_weights, cur_values)
    # 合并
    new_sum = prev_sum_corrected + cur_sum
    new_out = (prev_out_corrected + cur_out) / new_sum
    return new_max, new_sum, new_out

在昇腾上,torch.maximumtorch.exp.sum() 这些操作会被编译成 Vector Unit 的单条向量指令,一整行数据并行处理,而 GPU 上需要多个 CUDA thread 协作完成同样的操作。

ops-transformer 里的实现长什么样

ops-transformer 仓库里 FlashAttention 的代码结构大致是这样:

ops-transformer/
└── flash_attention/
    ├── flash_attention_score.py    # 主入口
    ├── flash_attention_grad.py     # 反向传播
    └── kernel/
        ├── flash_attention_tiling.py   # 分块策略
        └── flash_attention_kernel.cpp  # Ascend C 核心实现

核心逻辑在 flash_attention_kernel.cpp 里,用 Ascend C 写的。如果你熟悉 CUDA 编程,看这个文件会有种似曾相识的感觉,但编程模型完全不同。

几个关键点:

Tiling 策略flash_attention_tiling.py 里根据 seq_len、head_dim、UB 容量自动计算最优的 tile 大小。这个策略直接影响性能,太大了 UB 放不下,太小了循环次数多、HBM 访问频繁。

Cube 和 Vector 的流水线:矩阵乘(QK^T、PV)走 Cube Unit,Softmax 和 exp 走 Vector Unit。实现里用双缓冲机制让两套单元交替工作,Cube 算当前块的时候 Vector 在处理上一块的 Softmax。

反向传播:FlashAttention 的反向传播比前向复杂很多,需要保留前向的 Softmax 归一化因子和某些中间结果。ops-transformer 里的反向实现用了重计算策略(recomputation),不把所有中间结果都存下来,而是在反向时重新算一遍需要的中间值,用计算换显存。

实际性能对比

在昇腾 910B 上用 LLaMA 13B 做推理,FlashAttention vs 标准 Attention 的性能差异:

实现 seq_len=2048 seq_len=4096 seq_len=8192
标准 Attention 42ms 156ms OOM
FlashAttention 18ms 38ms 82ms

seq_len 越长,FlashAttention 的优势越明显。8192 的时候标准实现直接 OOM 了,因为中间矩阵放不下。FlashAttention 通过分块计算把显存占用从 O(n²) 降到了 O(n),长序列场景下几乎是唯一的选择。

FlashAttention 看起来只是"把 Attention 分块算",但真正实现起来,每一个硬件差异都要针对性地处理。昇腾 NPU 的 UB 更大、Cube Unit 的数据排布不同、Vector Unit 的并行规约方式不同,这些差异决定了你不能直接把 NVIDIA 的实现搬过来用,得重新设计 tiling 策略和流水线调度。

好消息是 ops-transformer 仓库已经把这些都做好了,而且全面开源。如果你在做大模型推理优化,建议直接用仓库里的实现,不要自己从头写。如果性能还不满足需求,可以在现有实现基础上调 tiling 参数或者改进流水线策略。

理解了 FlashAttention 在昇腾上的实现方式,再看 MoE 算子、MC2 通信算子,思路是一样的:先搞清楚算法核心,再理解硬件差异,最后看具体实现怎么在两者之间做权衡。

Logo

作为“人工智能6S店”的官方数字引擎,为AI开发者与企业提供一个覆盖软硬件全栈、一站式门户。

更多推荐