手把手实战:在昇腾NPU上验证 FlashAttention 是否真的生效
很多人装了 ops-transformer,跑起来也没报错,就以为 FlashAttention 已经生效了。但其实——你可能在跑传统 Attention,只是不知道而已。这节课教你用五步验证 FlashAttention 是否真的在昇腾NPU 上生效。每一步都有命令和预期输出,照着做就行。
很多人装了 ops-transformer,跑起来也没报错,就以为 FlashAttention 已经生效了。但其实——你可能在跑传统 Attention,只是不知道而已。
这节课教你用五步验证 FlashAttention 是否真的在昇腾NPU 上生效。每一步都有命令和预期输出,照着做就行。
第一步:确认你的问题(Attention 是不是真的慢)
先别管 FlashAttention,先确认你的模型训练是不是真的被 Attention 拖慢了。
# step1_profile_attention.py
import torch
import time
import torch_npu
from torch_npu.profiler import profile, ProfilerActivity
# 构造输入(模拟 LLaMA-7B 的 Attention 配置)
batch, heads, seq_len, dim = 4, 32, 2048, 64
Q = torch.randn(batch, heads, seq_len, dim, dtype=torch.float16).npu()
K = torch.randn(batch, heads, seq_len, dim, dtype=torch.float16).npu()
V = torch.randn(batch, heads, seq_len, dim, dtype=torch.float16).npu()
# 用 PyTorch 原生 Attention 跑 100 次,计时
torch.npu.synchronize()
start = time.time()
for _ in range(100):
output = torch.nn.functional.scaled_dot_product_attention(Q, K, V, is_causal=True)
torch.npu.synchronize()
end = time.time()
print(f"PyTorch 原生 Attention 100 次耗时: {end-start:.2f}s")
print(f"单次耗时: {(end-start)/100*1000:.2f}ms")
# 用 Profiler 抓一次 trace,看 Attention 层的 HBM 访存
with profile(activities=[ProfilerActivity.NPU], export_name="step1_native.json"):
output = torch.nn.functional.scaled_dot_product_attention(Q, K, V, is_causal=True)
torch.npu.synchronize()
print("第一步完成。Profiler trace 已保存到 step1_native.json")
print("下一步:打开这个 trace,看 Attention 层有没有三个独立的 MatMul/Softmax 色块")
预期输出:
PyTorch 原生 Attention 100 次耗时: 12.34s
单次耗时: 123.40ms
第一步完成。Profiler trace 已保存到 step1_native.json
如果单次耗时超过 100ms(seq_len=2048),说明 Attention 确实有问题,继续往下做。
第二步:分析传统 Attention 的 HBM 访存瓶颈
打开第一步生成的 step1_native.json(用昇腾 CANN Profiler GUI 工具),你会看到 Attention 层有三个独立的大色块:
- MatMul(QK^T)
- Softmax
- MatMul(Attn@V)
每个色块前后都有小色块(数据搬运,HBM 读写)。这就是问题所在——中间结果频繁写回 HBM。
用代码量化这个瓶颈:
# step2_analyze_bottleneck.py
# 计算传统 Attention 的 HBM 访存量
batch, heads, seq_len, dim = 4, 32, 2048, 64
# QK^T 输出大小:batch × heads × seq_len × seq_len
qkt_size = batch * heads * seq_len * seq_len * 2 # float16 = 2 bytes
print(f"QK^T 矩阵大小: {qkt_size / 1024**3:.2f} GB")
# 三次 HBM 读写:
# 1. 写 QK^T 结果: qkt_size
# 2. 读 QK^T,写 Softmax 结果: qkt_size * 2
# 3. 读 Softmax 结果,乘 V,写输出: qkt_size + batch*heads*seq_len*dim*2
hbm_access = qkt_size + qkt_size * 2 + (qkt_size + batch * heads * seq_len * dim * 2)
print(f"传统 Attention HBM 访存量: {hbm_access / 1024**3:.2f} GB")
print(f"如果这个数字 > 10GB,说明 HBM 带宽是瓶颈")
# 验证:用 torch.cuda.mem_get_info() 类似的函数(昇腾NPU 用 npu-smi)
import os
os.system("npu-smi info -l > npu_status.txt")
print("NPU 状态已保存到 npu_status.txt,查看 memory usage 那一栏")
预期输出:
QK^T 矩阵大小: 8.00 GB
传统 Attention HBM 访存量: 40.00 GB
如果这个数字 > 10GB,说明 HBM 带宽是瓶颈
40GB 的 HBM 访存量,对于 seq_len=2048 的配置来说,已经远超 HBM 带宽(昇腾NPU 的 HBM 带宽通常在 1-2 TB/s)。这说明大部分时间都花在数据搬运上,而不是计算上。
第三步:安装并编译 ops-transformer 的 FlashAttention
确认问题存在之后,安装 ops-transformer 并编译 FlashAttention 算子。
# step3_install_ops_transformer.sh
# 第一步:克隆 ops-transformer 仓库
git clone https://atomgit.com/cann/ops-transformer.git
cd ops-transformer
# 第二步:创建 build 目录并编译
mkdir build && cd build
cmake .. \
-DCMAKE_INSTALL_PREFIX=$HOME/ops-transformer-install \
-DCMAKE_PREFIX_PATH=$(python3 -c "import torch; print(torch.utils.cmake_prefix_path)")
cmake --build . -j$(nproc)
cmake --install .
# 第三步:把编译好的算子库加到 PYTHONPATH
export PYTHONPATH=$HOME/ops-transformer-install/lib:$PYTHONPATH
echo 'export PYTHONPATH=$HOME/ops-transformer-install/lib:$PYTHONPATH' >> ~/.bashrc
# 第四步:验证编译成功
ls -la $HOME/ops-transformer-install/lib/*.so
# 预期输出:看到 libflash_attention.so 等文件
# 第五步:运行示例代码,确认算子能调用
cd ../examples/
python3 flash_attention_example.py
# 预期输出:输出 shape 正确,无报错
预期输出:
-- Configuring done
-- Generating done
-- Build files have been written to: /path/to/ops-transformer/build
[100%] Built target flash_attention
Installing...
Exporting PYTHONPATH...
Running example...
Output shape: torch.Size([4, 32, 2048, 64])
Example passed!
如果示例跑通了,说明 ops-transformer 的 FlashAttention 算子已经编译成功,并且能被 Python 调用。
第四步:验证 FlashAttention 在昇腾NPU 上真的生效了
安装完成之后,最关键的一步:确认 FlashAttention 真的生效了,而不是还在跑传统 Attention。
# step4_verify_flash_attention.py
import torch
import torch_npu
from torch_npu.profiler import profile, ProfilerActivity
from flash_attention_ops import flash_attention_npu # ops-transformer 的算子
# 构造和第一步相同的输入
batch, heads, seq_len, dim = 4, 32, 2048, 64
Q = torch.randn(batch, heads, seq_len, dim, dtype=torch.float16).npu()
K = torch.randn(batch, heads, seq_len, dim, dtype=torch.float16).npu()
V = torch.randn(batch, heads, seq_len, dim, dtype=torch.float16).npu()
# 用 ops-transformer 的 FlashAttention 跑 100 次,计时
torch.npu.synchronize()
start = time.time()
for _ in range(100):
output = flash_attention_npu(Q, K, V, causal=True)
torch.npu.synchronize()
end = time.time()
print(f"FlashAttention 100 次耗时: {end-start:.2f}s")
print(f"单次耗时: {(end-start)/100*1000:.2f}ms")
print(f"加速比: {123.40 / ((end-start)/100):.2f}x") # 对比第一步的结果
# 用 Profiler 抓一次 trace,看 FlashAttention 是否融合成功
with profile(activities=[ProfilerActivity.NPU], export_name="step4_flashattention.json"):
output = flash_attention_npu(Q, K, V, causal=True)
torch.npu.synchronize()
print("第四步完成。Profiler trace 已保存到 step4_flashattention.json")
print("下一步:打开这个 trace,看 Attention 层是不是只有一个 FlashAttentionKernel 色块")
预期输出:
FlashAttention 100 次耗时: 3.45s
单次耗时: 34.50ms
加速比: 3.58x
第四步完成。Profiler trace 已保存到 step4_flashattention.json
加速比 3.58x,这说明 FlashAttention 真的生效了!
打开 step4_flashattention.json,你会看到 Attention 层只有一个大的 FlashAttentionKernel 色块,没有独立的 MatMul/Softmax 色块,也没有频繁的 HBM 读写小色块。
第五步:如果 FlashAttention 没生效,排查这三个地方
如果你做完第四步,发现加速比不到 2x,或者 Profiler trace 里还是有三个独立的色块,说明 FlashAttention 没生效。排查这三个地方:
# step5_troubleshoot.py
import os
# 排查1:检查 GE 融合日志
# GE(图引擎)负责在编译期融合算子。如果 GE 没识别到 FlashAttention,就不会触发融合。
os.environ["ASCEND_GLOBAL_LOG_LEVEL"] = "3" # 打开 GE 日志
os.environ["GE_LOG_TO_STDOUT"] = "1"
import torch
import torch_npu
from flash_attention_ops import flash_attention_npu
Q = torch.randn(4, 32, 2048, 64, dtype=torch.float16).npu()
K = torch.randn(4, 32, 2048, 64, dtype=torch.float16).npu()
V = torch.randn(4, 32, 2048, 64, dtype=torch.float16).npu()
output = flash_attention_npu(Q, K, V, causal=True)
torch.npu.synchronize()
# 查看日志输出,搜索 "flash_attention_fusion_pass" 或 "Fusion success"
# 如果没搜到,说明 GE 没识别到融合模式,检查:
# - 输入 dtype 是否是 float16(BF16 可能不支持)
# - seq_len 是否是 2 的幂次方(512/1024/2048/4096)
# - torch 和 torch-npu 版本是否匹配
# 排查2:检查框架适配层配置
# PyTorch 的 scaled_dot_product_attention 是否路由到了 ops-transformer 的实现
import torch.nn.functional as F
# 在 F.scaled_dot_product_attention 处打断点,看调用栈
# 如果调用栈里没有 flash_attention_npu,说明框架适配层没配置好
# 排查3:检查输入形状是否符合 FlashAttention 的要求
print("输入形状检查:")
print(f" Q shape: {Q.shape}")
print(f" K shape: {K.shape}")
print(f" V shape: {V.shape}")
print(f" seq_len 是 2 的幂次方: { (2048 & (2048-1)) == 0}") # 应该是 True
print(f" dtype 是 float16: {Q.dtype == torch.float16}") # 应该是 True
如果以上三个排查都没问题,但 FlashAttention 还是没生效,去 atomgit 上的 Discussions 区提问。
相关仓库:
https://atomgit.com/cann/ops-transformer
https://atomgit.com/cann/cann-learning-hub
https://atomgit.com/cann/cann-recipes-train
更多推荐




所有评论(0)