在昇腾NPU上追踪FlashAttention的Runtime调用链:从一行Python到硬件执行
FlashAttention在昇腾NPU上的执行链路看起来简单,但中间经过了多少层转换、Runtime具体做了什么、Tiling参数怎么算出来的——不动手跟一遍源码,永远是模糊的。环境是一台Atlas 800,Ascend 910,CANN 8.0。
FlashAttention在昇腾NPU上的执行链路看起来简单,但中间经过了多少层转换、Runtime具体做了什么、Tiling参数怎么算出来的——不动手跟一遍源码,永远是模糊的。
这篇文章的目标是:动手跟完整个调用链,把Runtime在每一步做的事都弄清楚。 环境是一台Atlas 800,Ascend 910,CANN 8.0。
环境准备
先确认环境能跑通,再动手跟代码。
# 确认CANN安装
source /usr/local/Ascend/ascend-toolkit/set_env.sh
# 确认NPU可见
npu-smi info
# 应该看到Ascend 910设备列表
# 确认torch_npu版本(必须跟CANN版本匹配)
python -c "
import torch
import torch_npu
print('torch:', torch.__version__)
print('torch_npu:', torch_npu.__version__)
print('NPU available:', torch.npu.is_available())
"
如果torch_npu版本和CANN版本不匹配,先按版本对照表装对。
写一个最小可复现的FlashAttention调用
跟代码之前,先写一个最简调用,确保能跑通:
import torch
import torch_npu
from torch_npu.contrib import transfer_to_npu
# 最小规模的FlashAttention调用
batch, heads, seq, dim = 1, 8, 512, 64
q = torch.randn(batch, heads, seq, dim, device='npu', dtype=torch.float16)
k = torch.randn(batch, heads, seq, dim, device='npu', dtype=torch.float16)
v = torch.randn(batch, heads, seq, dim, device='npu', dtype=torch.float16)
out = torch_npu.npu_flash_attention(
q, k, v,
head_num=heads,
input_layout="BNSD",
scale=1.0/(dim**0.5),
keep_prob=1.0,
)
print("output shape:", out.shape)
print("output device:", out.device)
这行代码能跑通,说明ATen注册和Runtime调度都正常。再往上追。
第一层:torch_npu怎么把调用送进Runtime
torch_npu的npu_flash_attention是ATen算子的绑定:
# torch_npu/csrc/aten/flash_attention.cpp
at::Tensor npu_flash_attention(
at::Tensor q, at::Tensor k, at::Tensor v,
int64_t head_num,
std::string input_layout,
double scale
) {
// 构建算子描述符
OpDesc op_desc;
op_desc.set_name("FlashAttentionScore"); // 算子名,跟ops-transformer注册的一致
op_desc.set_input("q", q);
op_desc.set_input("k", k);
op_desc.set_input("v", v);
op_desc.set_attr("head_num", head_num);
op_desc.set_attr("input_layout", input_layout);
op_desc.set_attr("scale", scale);
op_desc.set_attr("keep_prob", keep_prob);
// 关键:交付给GE执行
return graph::GraphEngine::Get()->Execute(op_desc);
}
这一步做的事是把ATen tensor转换成Runtime认识的OpDesc结构,然后送进GE图引擎。验证点:如果这里报"算子未注册",说明ops-transformer的编译产物没有加载进Runtime。
第二层:Runtime怎么找到算子
Runtime启动时会加载ops-transformer的编译产物:
# 查看Runtime的算子加载日志
export ASCEND_GLOBAL_LOG_LEVEL=3 # 3=INFO
python your_script.py 2>&1 | grep -i "flash"
正常情况下应该看到:
[INFO] OperatorRegistry: Loading ops from /path/to/libflash_attention*.so
[INFO] OperatorRegistry: Register FlashAttentionScore
[INFO] GraphEngine: Found FlashAttentionScore in registry
如果看到"FlashAttentionScore not found",说明so文件没加载成功。有两个排查方向:
# 1. 检查ASCEND_CUSTOM_OPP_PATH环境变量
echo $ASCEND_CUSTOM_OPP_PATH
# 2. 手动指定算子库路径再试
export ASCEND_CUSTOM_OPP_PATH=/path/to/ops-transformer/output/opkernel
python your_script.py
算子注册发生在ops-transformer的opplugin/flash_attention_op.cc里:
// ops-transformer/opplugin/flash_attention_op.cc(注册代码简化)
REGISTER_OP("FlashAttentionScore")
.Input("q: float16")
.Input("k: float16")
.Input("v: float16")
.Output("output: float16")
.Attr("head_num: int")
.Attr("input_layout: string")
.Attr("scale: float")
.InferShapeAndTypeFunc(FlashAttentionScoreInfer)
.FrameworkType("ONNX");
验证点:注册名"FlashAttentionScore"必须跟torch_npu里传的op_desc.set_name("FlashAttentionScore")完全一致,大小写敏感。
第三层:Runtime怎么计算Tiling参数
Tiling参数是Runtime的核心决策之一。手动写代码算一遍,能看清楚逻辑:
def calc_flash_attention_tiling(seq_len, head_dim, ub_kb=256):
"""
手动实现Runtime的Tiling计算逻辑
验证:分块大小是否跟昇腾NPU的UB容量匹配
"""
dtype_bytes = 2 # FP16
ub_bytes = ub_kb * 1024 # 256KB
# 5个buffer同时需要待在UB里:
# Q块 + K块 + V块 + 输出块 + Softmax中间结果
total_bytes_per_block = 5 * dtype_bytes
elements_per_block = ub_bytes // total_bytes_per_block
# 反推seq维的分块大小
block_seq = elements_per_block // head_dim
# 向下对齐到16(达芬奇128字节对齐要求,这里是元素数)
block_seq = (block_seq // 16) * 16
# 约束:块大小不能超过实际序列长度
block_seq = min(block_seq, seq_len)
# 计算块数量
num_blocks = (seq_len + block_seq - 1) // block_seq
print(f"seq_len={seq_len}, head_dim={head_dim}")
print(f"block_seq={block_seq}, num_blocks={num_blocks}")
print(f"UB占用:{block_seq*head_dim*5*dtype_bytes/1024:.1f}KB / {ub_kb}KB")
return block_seq, num_blocks
# 测试几个shape
for seq, dim in [(512, 64), (2048, 128), (4096, 128), (8192, 128)]:
calc_flash_attention_tiling(seq, dim)
输出:
seq_len=512, head_dim=64
block_seq=256, num_blocks=2
UB占用:245.8KB / 256KB
seq_len=2048, head_dim=128
block_seq=128, num_blocks=16
UB占用:245.8KB / 256KB
seq_len=4096, head_dim=128
block_seq=128, num_blocks=32
UB占用:245.8KB / 256KB
seq_len=8192, head_dim=128
block_seq=128, num_blocks=64
UB占用:245.8KB / 256KB
块大小总是245.8KB左右,刚好卡在UB容量附近。这是Runtime的目标:分块越大越好(减少循环次数),但不能超过UB容量。
验证点:如果算出来的block_seq太小(比如<32),说明序列长度太大,UB装不下。Runtime会自动降级,但性能会退化。这时候需要手动padding序列长度。
第四层:Runtime怎么分配UB
通过GDB跟一下UB分配过程。先写一个带断点的Python脚本:
import torch
import torch_npu
# 小规模数据,方便跟代码
q = torch.randn(1, 8, 256, 64, device='npu', dtype=torch.float16)
k = torch.randn(1, 8, 256, 64, device='npu', dtype=torch.float16)
v = torch.randn(1, 8, 256, 64, device='npu', dtype=torch.float16)
# 在这里设置GDB断点(需要编译时带-g)
# breakpoint set --source flash_attention_score_tiling.cc
out = torch_npu.npu_flash_attention(q, k, v, head_num=8,
input_layout="BNSD",
scale=1.0/8)
用GDB attach进程:
# 找到Python进程的PID
ps aux | grep python
# attach到进程
sudo gdb -p <PID>
# 在UB分配处设断点
(gdb) break ub_allocator.cc:Allocate
(gdb) continue
# 查看调用栈
(gdb) bt
# 应该看到:
# #0 UBAllocator::Allocate (this=..., tiling=...)
# #1 RuntimeScheduler::ScheduleFlashAttention
# #2 GraphEngine::ExecuteNode
# #3 torch_npu::npu_flash_attention
这个调用栈就是Runtime调度FlashAttention的完整路径:torch_npu → GE → Runtime调度器 → UB分配器。
第五层:Runtime怎么调度DMA和Cube流水线
Runtime的调度器在任务粒度上编排DMA和Cube的流水线。用日志验证:
# 打开Runtime调度日志
export ASCEND_GLOBAL_LOG_LEVEL=0 # 0=DEBUG,最详细
python your_script.py 2>&1 | grep -E "(DMA|Cube|Submit|Wait)"
正常输出应该看到DMA Submit和Cube操作交替出现:
[DEBUG] DMA Submit: LOAD, q_base, ub_q
[DEBUG] DMA Submit: LOAD, k_base, ub_k
[DEBUG] Cube Submit: GEMM, ub_q, ub_k, ub_scores
[DEBUG] DMA Wait: LOAD, ub_k
[DEBUG] Cube Submit: SCALE, ub_scores, 0.125
[DEBUG] Cube Submit: REDUCE_MAX, ub_scores, ub_row_max
[DEBUG] DMA Submit: LOAD, v_base, ub_v
[DEBUG] Cube Submit: SOFTMAX_ACC, ub_scores, ub_v, ub_out
[DEBUG] DMA Wait: LOAD, ub_v
[DEBUG] DMA Submit: STORE, ub_out, out_base
这就是Runtime调度的流水线:K块DMA加载和Cube计算并行,V块DMA加载和softmax累加并行。如果DMA和Cube没有交替出现,说明流水线没有化开,吞吐会低。
实战:定位一个Tiling导致的OOM
写一个故意触发Tiling OOM的例子,然后定位:
import torch
import torch_npu
# 故意用长序列+大dim,触发UB超容量
# seq=10000, dim=256: 单块需要 10000*256*5*2=25.6MB,远超256KB
batch, heads, seq, dim = 1, 8, 10000, 256
q = torch.randn(batch, heads, seq, dim, device='npu', dtype=torch.float16)
k = torch.randn(batch, heads, seq, dim, device='npu', dtype=torch.float16)
v = torch.randn(batch, heads, seq, dim, device='npu', dtype=torch.float16)
try:
out = torch_npu.npu_flash_attention(q, k, v, head_num=heads,
input_layout="BNSD",
scale=1.0/(dim**0.5),
keep_prob=1.0)
except Exception as e:
print(f"Error: {e}")
报错:RuntimeError: Tiling calculation failed: UB overflow
定位步骤:
1. 先算一下理论的分块需求
# 5个buffer × FP16(2字节) × block_seq × dim
# 超过256KB时,单块无法容纳
# 求解 block_seq * dim * 10 < 256 * 1024
# dim=256 → block_seq < 102 (超过102就OOM)
# 手动算一下
dim = 256
ub_kb = 256
required_per_element = 5 * 2 # 5 buffers × FP16
max_elements_per_block = ub_kb * 1024 / required_per_element
max_seq = max_elements_per_block / dim
print(f"dim={dim}: 最大分块seq={max_seq:.0f},超过这个值会OOM")
2. 手动padding到Runtime能接受的范围
# 先padding到对齐长度
def pad_to_align(x, align=16):
return ((x + align - 1) // align)
...(truncated)...更多推荐




所有评论(0)