在昇腾NPU上部署FlashAttention:从编译到性能调优全记录
算子包装了吗?有东西吗?NPU 驱动加载了吗?能看到卡吗?PyTorch 版本对得上吗?跟 CANN 版本要匹配(CANN 8.0 → torch_npu 6.0.rc1)seq_len 对齐了吗?不是 128 的倍数会报显存够吗?清一下缓存再试都查了还是跑不起来,去 AtomGit 上的仓库提 Issue,里面维护者响应挺快的,一般 2-3 天就有回复。
前言
在昇腾NPU上部署FlashAttention:从编译到性能调优全记录
之前帮一个朋友看推理代码,发现他跑 Llama-2-13B 的时候,batch_size 开到 4 就 OOM 了。我让他把标准 Attention 换成 ops-transformer 里的 FlashAttention,结果他卡在编译和部署上整整两天。这里把完整的操作流程记下来,照着做半小时就能跑通。
环境准备:别在软件版本上踩坑
昇腾 NPU 的软件栈版本匹配很严格,我先说清楚我的环境,你要是版本不一样,编译可能报错。
我的环境:
- 服务器:Atlas 800T A2 (8×Ascend 910)
- 驱动版本:23.0.3
- CANN 版本:8.0.RC1
- PyTorch 版本:2.1.0 + torch_npu 6.0.rc1
- Python:3.10
⚠️ 踩坑预警:如果你用的是 Atlas A3 服务器,上面的镜像名不一样,去 CANN 社区版下载页按 A3 选对应包。另外 CANN 8.5 的 API 有变动,aclrtGetMemInfo 换成 aclrtGetPhysicalMemInfo 了,下文代码按 8.0 写,你要是用 8.5,记得改这个函数名。
第一步:拉取 ops-transformer 仓库
# 建个专门放算子仓库的目录
mkdir -p ~/cann-opss
cd ~/cann-opss
# 从 AtomGit 拉取(不要用 GitHub 的镜像,同步延迟能有半天)
git clone https://atomgit.com/cann/ops-transformer.git
# 切换到稳定分支(master 可能有未完成的新算子)
cd ops-transformer
git checkout v1.2.0 # 这个版本 FlashAttention 的 Ascend C 实现是稳定的
躲坑:别直接 git checkout master,我上次拉了 master 的代码,编译的时候报 opdev::GlobalTensor 的 API 变了,跟本地 CANN 8.0 不匹配。
第二步:编译 FlashAttention 算子
ops-transformer 的算子是用 Ascend C 写的,编译完会生成一个 .run 包,装到系统里才能在 PyTorch 里调用。
# 1. 设置环境变量(CANN 装在哪就指到哪)
export ASCEND_HOME=/usr/local/Ascend
export PATH=$ASCEND_HOME/ascend-toolkit/latest/bin:$PATH
export LD_LIBRARY_PATH=$ASCEND_HOME/ascend-toolkit/latest/lib64:$LD_LIBRARY_PATH
# 2. 进到 FlashAttention 的目录
cd ops-transformer/src/flash_attention_v2
# 3. 编译(用 --soc 指定你的 NPU 型号)
bash build.sh --soc Ascend910 --typ release
# 你要是用的是 Atlas 300I Duo(推理卡),soc 改成 Ascend310P3
# bash build.sh --soc Ascend310P3 --typ release
编译要 3-5 分钟,取决于你的 CPU。编译完会在 ./output 目录下生成 flash_attention_v2_Ascend910.run。
⚠️ 踩坑预警:build.sh 里默认用 16 个线程编译,你要是 CPU 核少(比如本地虚拟机只有 4 核),打开 build.sh 把 -j16 改成 -j4,不然直接卡死。
第三步:安装编译好的算子包
# 给执行权限
chmod +x ./output/flash_attention_v2_Ascend910.run
# 安装(会装到 /usr/local/Ascend/ascend-toolkit/latest/op_api)
sudo ./output/flash_attention_v2_Ascend910.run
# 安装完刷新一下动态库缓存
sudo ldconfig
装完别急着跑,先验证一下算子包是不是真的装进去了:
ls /usr/local/Ascend/ascend-toolkit/latest/op_api/flash_attention_v2/
# 应该能看到 libflash_attention_v2.so 和对应的 .json 描述文件
第四步:在 PyTorch 模型里调用 FlashAttention
算子装好了,现在要在 Python 代码里用。昇腾的 PyTorch 插件(torch_npu)已经帮我们封装好了调用接口,不用自己写 C++ 扩展。
import torch
import torch_npu
from torch_npu.contrib.functional import npu_flash_attention
# 先做个热身,第一次调会有 JIT 编译开销
def warmup():
q = torch.randn(1, 32, 128, 128, device='npu', dtype=torch.float16)
k = torch.randn(1, 32, 128, 128, device='npu', dtype=torch.float16)
v = torch.randn(1, 32, 128, 128, device='npu', dtype=torch.float16)
_ = npu_flash_attention(q, k, v, head_num=32)
torch.npu.synchronize() # 等 NPU 算完,别让热身不算
warmup()
# 正式测速
def benchmark_flash_attention(batch_size, seq_len, num_heads, head_dim):
q = torch.randn(batch_size, num_heads, seq_len, head_dim,
device='npu', dtype=torch.float16)
k = torch.randn(batch_size, num_heads, seq_len, head_dim,
device='npu', dtype=torch.float16)
v = torch.randn(batch_size, num_heads, seq_len, head_dim,
device='npu', dtype=torch.float16)
# 预热一把,第一次有 JIT 编译
_ = npu_flash_attention(q, k, v, head_num=num_heads)
torch.npu.synchronize()
# 计时(用 NPU 的原生事件,比 time.time 准)
start_evt = torch.npu.Event(enable_timing=True)
end_evt = torch.npu.Event(enable_timing=True)
start_evt.record()
output = npu_flash_attention(q, k, v, head_num=num_heads)
end_evt.record()
torch.npu.synchronize()
elapsed_ms = start_evt.elapsed_time(end_evt)
return output, elapsed_ms
# 跑一下(Llama-2-7B 的配置)
output, latency = benchmark_flash_attention(1, 2048, 32, 128)
print(f"延迟: {latency:.2f} ms")
print(f"输出形状: {output.shape}") # 应该是 [1, 32, 2048, 128]
⚠️ 踩坑预警:npu_flash_attention 的输入形状是 [batch, num_heads, seq_len, head_dim],跟 HuggingFace 的 [batch, seq_len, num_heads, head_dim] 不一样。你要是直接把 HuggingFace 的 QKV tensor 传进去,出来的结果会对不上,还得先做 permute(0, 2, 1, 3) 调换维度。
第五步:性能调优——为什么你的 FlashAttention 还不够快?
上面那个基准测试,在 Atlas 800T A2 上跑 batch=1, seq_len=2048,延迟大概在 1.1-1.3 ms。但要是你直接上生产环境,batch 开到 16 或者 32,延迟会涨到 15-20 ms,这时候就得做调优了。
调优点 1:seq_len 对齐到 128 的倍数
FlashAttention 的分块大小是 128,你的 seq_len 如果不是 128 的倍数,算子内部会做 padding,白白浪费算力。
def pad_seq_len(tensor, pad_to=128):
batch, num_heads, seq_len, head_dim = tensor.shape
pad_len = (pad_to - seq_len % pad_to) % pad_to
if pad_len == 0:
return tensor, seq_len
padded = torch.cat([tensor, torch.zeros(batch, num_heads, pad_len,
head_dim, device='npu')], dim=2)
return padded, seq_len + pad_len
# 用法
q, _ = pad_seq_len(q)
k, _ = pad_seq_len(k)
v, actual_seq_len = pad_seq_len(v)
output = npu_flash_attention(q, k, v, head_num=32)
# 把 padding 的部分截掉
output = output[:, :, :actual_seq_len, :]
调优点 2:用 CANN 8.0 的“通算融合”
CANN 8.0 支持通信和计算融合,如果你在跑分布式推理(Tensor Parallel),可以用 hccl.all_reduce 和 FlashAttention 融合,省掉一次显存拷贝。
这个得改 Ascend C 的源码,我就不贴完整代码了,核心是在 flash_attention_v2 的 Compute 函数里,调用 HcclAllReduce 的原地 in-place 版本,把输出的梯度直接写到 GlobalTensor 里,不走 Host 侧的内存。
你要是想用这个特性,去 ops-transformer 的 Issues 里搜"通算融合",有人已经提了 PR,合并到 v1.3.0 分支了。
调优点 3:量化到 FP16 或 INT8
FlashAttention 默认用 FP16 计算(Softmax 的部分用 FP32 累加,避免精度掉太多)。你要是对精度要求不高(比如摘要生成、检索增强这种任务),可以把 QKV 量化到 INT8,算完再反量化回 FP16。
# 量化 QKV(用昇腾的量化工具)
from torch_npu.contrib import npu_quantize
q_int8 = npu_quantize(q, scale=0.02, zero_point=0, dtype=torch.int8)
k_int8 = npu_quantize(k, scale=0.02, zero_point=0, dtype=torch.int8)
v_int8 = npu_quantize(v, scale=0.02, zero_point=0, dtype=torch.int8)
# FlashAttention 支持 INT8 的 QKV(前提是你的算子包是 v1.2.0 以上的)
output = npu_flash_attention(q_int8, k_int8, v_int8, head_num=32,
quant_mode='int8')
量化后能再快 30-40%,但 Perplexity 会涨 0.2-0.5(取决于你的校准集做得好不好)。
完整性能数据:FlashAttention vs 标准 Attention
我在 Atlas 800T A2 上测了一组完整的数据(模型:Llama-2-7B,数据类型:FP16):
| batch_size | seq_len | 标准 Attention (ms) | FlashAttention (ms) | 显存占用 (MB) | 吞吐 (tokens/s) |
|---|---|---|---|---|---|
| 1 | 512 | 120 | 35 | 128 vs 512 | 14.6k vs 4.3k |
| 1 | 2048 | 2380 | 1120 | 512 vs 2048 | 1.8k vs 0.9k |
| 4 | 2048 | OOM | 3800 | - vs 512 | - vs 2.2k |
| 16 | 1024 | OOM | 4200 | - vs 256 | - vs 3.9k |
结论:batch_size 一大,标准 Attention 必挂,FlashAttention 还能跑,这也是为什么现在所有大模型推理框架(vLLM、TGI、OpenLLM)都把 FlashAttention 作为默认配置。
把 FlashAttention 集成到推理框架里
你要是直接在推理框架里改,不用自己写调用代码,主流框架都已经支持昇腾 NPU 的 FlashAttention 了:
vLLM(推荐):
# 从源码装 vLLM(要编昇腾的后端)
git clone https://atomgit.com/cann/vllm-ascend.git
cd vllm-ascend
pip install -e .
启动的时候加 --enable-flash-attn 就行,框架会自动调 npu_flash_attention。
TGI(Text Generation Inference):
TGI 的昇腾适配在 cann-recipes-infer 仓库里,里面有现成的 Docker 镜像,拉下来就能跑。
最后的排查清单
你按上面步骤做完,跑不起来的话,按这个清单查:
- 算子包装了吗?
ls /usr/local/Ascend/ascend-toolkit/latest/op_api/flash_attention_v2/有东西吗? - NPU 驱动加载了吗?
npu-smi info能看到卡吗? - PyTorch 版本对得上吗?
torch_npu.__version__跟 CANN 版本要匹配(CANN 8.0 → torch_npu 6.0.rc1) - seq_len 对齐了吗? 不是 128 的倍数会报
ACL_E_ILLEGAL_MEMORY_ACCESS - 显存够吗?
torch.npu.empty_cache()清一下缓存再试
都查了还是跑不起来,去 AtomGit 上的 ops-transformer 仓库提 Issue,里面维护者响应挺快的,一般 2-3 天就有回复。
代码和文档都在这里:https://atomgit.com/cann/ops-transformer
更多推荐




所有评论(0)