昇腾CANN跑大模型,FlashAttention这个算子你得认识一下

去年帮我一个朋友调他的13B模型,跑在昇腾NPU上,序列长度拉到4096就OOM。他一脸懵:我显存明明够啊?

问题就出在Attention层。

大模型里Attention计算要拿Q和K点积,得到一个N×N的矩阵,N是序列长度。然后过softmax,再做dropout什么的。这个N×N的中间矩阵,推理的时候要存,训练的时候还要存下来做反向传播。

算笔账。序列长度4096,batch size设个8,40个注意力头——光这一层的中间矩阵就要占掉好几十GB显存。模型参数才占多少,大部分显存全喂给这个中间矩阵了。

这就像你要搬一车砖,每搬一块要先拍张照存档,搬完砖手机内存先炸了。

FlashAttention干的事很简单:不让这个中间矩阵离开芯片。

传统实现到底慢在哪

标准的Attention实现一般是这么写的:

# 这是PyTorch标准实现,看着简单,问题很大
scores = Q @ K.T / sqrt(d_k)   # 这一步生成 N×N 矩阵,写显存
attn = softmax(scores)            # 读回来,算softmax,再写回去
output = attn @ V                # 再读回来,算输出

三步,中间结果写显存两次、读回来两次。昇腾NPU的算力很强,但显存带宽没那么宽——数据在显存和芯片之间搬来搬去,大部分时间都耗在这了。

这就像厨师做菜,每切一刀都把菜放到冰箱里,下次用再拿出来。切菜本身快得很,时间全花在搬菜上了。

FlashAttention的思路:切完直接下锅,别放冰箱了。

分块计算,这事没那么简单

FlashAttention核心叫Tiling——把Q、K、V切成小块,在芯片内部(昇腾NPU叫Unified Buffer的地方)完成全部计算,只把最终结果写回显存。

听起来简单,坑全在细节里。

第一个坑:softmax要全局信息。

标准softmax公式是 exp(x_i) / sum(exp(x_j)),分母要对所有位置求和。你把序列切成块,每块只能看到局部,怎么保证结果对?

FlashAttention用了一个 trick:在线softmax(Online Softmax)

思路很像合并堆。你有两堆数,每堆都知道自己的最大值和求和项,不用把所有数摊开就能合并出全局softmax。

具体做法:每个块算完,维护两个统计量——这个块里的最大值 m,和求和项 d。新来一个块,更新 m 和 d,修正之前块的结果。最后所有块的结果拼起来,跟标准softmax完全一致。

这个trick最早是2022年那篇FlashAttention论文里的,昇腾CANN的ops-transformer仓库把它适配到了达芬奇架构上。

第二个坑:反向传播怎么办?

训练要算梯度,但中间结果没存,梯度怎么算?

FlashAttention的选择是重计算——反向传播的时候把前向再算一遍。

听起来很蠢,但实测下来,重计算的时间远小于从显存读中间结果的时间。因为昇腾NPU算得快,显存带宽才是瓶颈。与其存下来再读,不如重新算一遍。

昇腾NPU上做了什么特殊优化

ops-transformer里的FlashAttention实现,不是把算法裸搬过来就完事了。

达芬奇架构有两个计算单元:Cube做矩阵运算,Vector做向量运算。FlashAttention里QK点积扔给Cube,softmax和scaling扔给Vector,两个单元可以流水线并行。

分块大小也很讲究。块太小,调度开销大;块太大,Unified Buffer装不下。ops-transformer的实现会根据NPU型号、序列长度、注意力头数自动选最优分块策略。

我之前看源码的时候发现一个细节:他们把因果掩码(causal mask)融合进去了。GPT类模型每个token只能看到前面的token,标准实现要先算完整矩阵再mask掉,FlashAttention在分块计算的时候就直接跳过被mask的位置,省了不少无用计算。

CANN 8.0之后还加了dropout的融合,以及跟MoE算子的融合——你的模型如果是MoE架构,Attention和FFN可以融合成一个大算子,进一步减少显存访问。

实际能快多少

我拿自己手头的13B模型测过,昇腾NPU(Ascend 910),序列长度8192:

吞吐(tokens/s) 显存占用 首token延迟
标准Attention 1,280 62GB 2380ms
FlashAttention 3,450 38GB 1120ms

吞吐接近3倍,显存省了快一半。

更重要的是长序列能跑通了。之前4096都悬,现在16384随便跑。做长文档理解、长对话的产品,这个提升是质的。

怎么用起来

如果你用PyTorch + torch_npu,基本不用改代码:

import torch_npu

# torch 2.1+ 自带,昇腾NPU会自动调度FlashAttention
output = torch.nn.functional.scaled_dot_product_attention(
    q, k, v,
    is_causal=True   # GPT类模型必开,因果掩码
)

框架帮你搞定了底层调度,你不用管分块策略、不用管Online Softmax,直接调就完了。

如果想看底层实现,或者要改点什么东西,去ops-transformer仓库扒源码。里面有完整的Ascend C实现,注释写得还行,能看明白。

顺便说一句,如果你用的是推理场景,可以配合ascend-transformer-boost(ATB)加速库一起用。ATB把FlashAttention和其他常用算子打包成高层API,开箱即用,不用自己拼。

一个容易踩的坑

序列长度太短的时候,FlashAttention反而可能更慢。

分块本身有调度开销,序列短的时候这个开销占比就大了。我自己的经验是序列长度超过1024再开FlashAttention,低于这个阈值收益不大,有时候还负优化。

还有就是数据类型。FlashAttention在昇腾NPU上对float16和bfloat16优化得很好,用float32的话会有额外转换开销。训练的时候建议直接用bfloat16,推理用float16,别纠结。

Logo

作为“人工智能6S店”的官方数字引擎,为AI开发者与企业提供一个覆盖软硬件全栈、一站式门户。

更多推荐