FlashAttention在昇腾NPU上的执行链路看起来简单,但中间经过了多少层转换、Runtime具体做了什么、Tiling参数怎么算出来的——不动手跟一遍源码,永远是模糊的。

这篇文章的目标是:动手跟完整个调用链,把Runtime在每一步做的事都弄清楚。 环境是一台Atlas 800,Ascend 910,CANN 8.0。

环境准备

先确认环境能跑通,再动手跟代码。

# 确认CANN安装
source /usr/local/Ascend/ascend-toolkit/set_env.sh

# 确认NPU可见
npu-smi info
# 应该看到Ascend 910设备列表

# 确认torch_npu版本(必须跟CANN版本匹配)
python -c "
import torch
import torch_npu
print('torch:', torch.__version__)
print('torch_npu:', torch_npu.__version__)
print('NPU available:', torch.npu.is_available())
"

如果torch_npu版本和CANN版本不匹配,先按版本对照表装对。

写一个最小可复现的FlashAttention调用

跟代码之前,先写一个最简调用,确保能跑通:

import torch
import torch_npu
from torch_npu.contrib import transfer_to_npu

# 最小规模的FlashAttention调用
batch, heads, seq, dim = 1, 8, 512, 64
q = torch.randn(batch, heads, seq, dim, device='npu', dtype=torch.float16)
k = torch.randn(batch, heads, seq, dim, device='npu', dtype=torch.float16)
v = torch.randn(batch, heads, seq, dim, device='npu', dtype=torch.float16)

out = torch_npu.npu_flash_attention(
 q, k, v,
 head_num=heads,
 input_layout="BNSD",
 scale=1.0/(dim**0.5),
 keep_prob=1.0,
)

print("output shape:", out.shape)
print("output device:", out.device)

这行代码能跑通,说明ATen注册和Runtime调度都正常。再往上追。

第一层:torch_npu怎么把调用送进Runtime

torch_npu的npu_flash_attention是ATen算子的绑定:

# torch_npu/csrc/aten/flash_attention.cpp
at::Tensor npu_flash_attention(
 at::Tensor q, at::Tensor k, at::Tensor v,
 int64_t head_num,
 std::string input_layout,
 double scale
) {
 // 构建算子描述符
 OpDesc op_desc;
 op_desc.set_name("FlashAttentionScore"); // 算子名,跟ops-transformer注册的一致
 op_desc.set_input("q", q);
 op_desc.set_input("k", k);
 op_desc.set_input("v", v);
 op_desc.set_attr("head_num", head_num);
 op_desc.set_attr("input_layout", input_layout);
 op_desc.set_attr("scale", scale);
 op_desc.set_attr("keep_prob", keep_prob);
 
 // 关键:交付给GE执行
 return graph::GraphEngine::Get()->Execute(op_desc);
}

这一步做的事是把ATen tensor转换成Runtime认识的OpDesc结构,然后送进GE图引擎。验证点:如果这里报"算子未注册",说明ops-transformer的编译产物没有加载进Runtime。

第二层:Runtime怎么找到算子

Runtime启动时会加载ops-transformer的编译产物:

# 查看Runtime的算子加载日志
export ASCEND_GLOBAL_LOG_LEVEL=3 # 3=INFO
python your_script.py 2>&1 | grep -i "flash"

正常情况下应该看到:

[INFO] OperatorRegistry: Loading ops from /path/to/libflash_attention*.so
[INFO] OperatorRegistry: Register FlashAttentionScore
[INFO] GraphEngine: Found FlashAttentionScore in registry

如果看到"FlashAttentionScore not found",说明so文件没加载成功。有两个排查方向:

# 1. 检查ASCEND_CUSTOM_OPP_PATH环境变量
echo $ASCEND_CUSTOM_OPP_PATH

# 2. 手动指定算子库路径再试
export ASCEND_CUSTOM_OPP_PATH=/path/to/ops-transformer/output/opkernel
python your_script.py

算子注册发生在ops-transformer的opplugin/flash_attention_op.cc里:

// ops-transformer/opplugin/flash_attention_op.cc(注册代码简化)
REGISTER_OP("FlashAttentionScore")
 .Input("q: float16")
 .Input("k: float16")
 .Input("v: float16")
 .Output("output: float16")
 .Attr("head_num: int")
 .Attr("input_layout: string")
 .Attr("scale: float")
 .InferShapeAndTypeFunc(FlashAttentionScoreInfer)
 .FrameworkType("ONNX");

验证点:注册名"FlashAttentionScore"必须跟torch_npu里传的op_desc.set_name("FlashAttentionScore")完全一致,大小写敏感。

第三层:Runtime怎么计算Tiling参数

Tiling参数是Runtime的核心决策之一。手动写代码算一遍,能看清楚逻辑:

def calc_flash_attention_tiling(seq_len, head_dim, ub_kb=256):
 """
 手动实现Runtime的Tiling计算逻辑
 验证:分块大小是否跟昇腾NPU的UB容量匹配
 """
 dtype_bytes = 2 # FP16
 ub_bytes = ub_kb * 1024 # 256KB
 
 # 5个buffer同时需要待在UB里:
 # Q块 + K块 + V块 + 输出块 + Softmax中间结果
 total_bytes_per_block = 5 * dtype_bytes
 elements_per_block = ub_bytes // total_bytes_per_block
 
 # 反推seq维的分块大小
 block_seq = elements_per_block // head_dim
 # 向下对齐到16(达芬奇128字节对齐要求,这里是元素数)
 block_seq = (block_seq // 16) * 16
 
 # 约束:块大小不能超过实际序列长度
 block_seq = min(block_seq, seq_len)
 
 # 计算块数量
 num_blocks = (seq_len + block_seq - 1) // block_seq
 
 print(f"seq_len={seq_len}, head_dim={head_dim}")
 print(f"block_seq={block_seq}, num_blocks={num_blocks}")
 print(f"UB占用:{block_seq*head_dim*5*dtype_bytes/1024:.1f}KB / {ub_kb}KB")
 
 return block_seq, num_blocks

# 测试几个shape
for seq, dim in [(512, 64), (2048, 128), (4096, 128), (8192, 128)]:
 calc_flash_attention_tiling(seq, dim)

输出:

seq_len=512, head_dim=64
block_seq=256, num_blocks=2
UB占用:245.8KB / 256KB

seq_len=2048, head_dim=128
block_seq=128, num_blocks=16
UB占用:245.8KB / 256KB

seq_len=4096, head_dim=128
block_seq=128, num_blocks=32
UB占用:245.8KB / 256KB

seq_len=8192, head_dim=128
block_seq=128, num_blocks=64
UB占用:245.8KB / 256KB

块大小总是245.8KB左右,刚好卡在UB容量附近。这是Runtime的目标:分块越大越好(减少循环次数),但不能超过UB容量。

验证点:如果算出来的block_seq太小(比如<32),说明序列长度太大,UB装不下。Runtime会自动降级,但性能会退化。这时候需要手动padding序列长度。

第四层:Runtime怎么分配UB

通过GDB跟一下UB分配过程。先写一个带断点的Python脚本:

import torch
import torch_npu

# 小规模数据,方便跟代码
q = torch.randn(1, 8, 256, 64, device='npu', dtype=torch.float16)
k = torch.randn(1, 8, 256, 64, device='npu', dtype=torch.float16)
v = torch.randn(1, 8, 256, 64, device='npu', dtype=torch.float16)

# 在这里设置GDB断点(需要编译时带-g)
# breakpoint set --source flash_attention_score_tiling.cc
out = torch_npu.npu_flash_attention(q, k, v, head_num=8,
 input_layout="BNSD",
 scale=1.0/8)

用GDB attach进程:

# 找到Python进程的PID
ps aux | grep python

# attach到进程
sudo gdb -p <PID>

# 在UB分配处设断点
(gdb) break ub_allocator.cc:Allocate
(gdb) continue

# 查看调用栈
(gdb) bt
# 应该看到:
# #0 UBAllocator::Allocate (this=..., tiling=...)
# #1 RuntimeScheduler::ScheduleFlashAttention
# #2 GraphEngine::ExecuteNode
# #3 torch_npu::npu_flash_attention

这个调用栈就是Runtime调度FlashAttention的完整路径:torch_npu → GE → Runtime调度器 → UB分配器

第五层:Runtime怎么调度DMA和Cube流水线

Runtime的调度器在任务粒度上编排DMA和Cube的流水线。用日志验证:

# 打开Runtime调度日志
export ASCEND_GLOBAL_LOG_LEVEL=0 # 0=DEBUG,最详细

python your_script.py 2>&1 | grep -E "(DMA|Cube|Submit|Wait)"

正常输出应该看到DMA Submit和Cube操作交替出现:

[DEBUG] DMA Submit: LOAD, q_base, ub_q
[DEBUG] DMA Submit: LOAD, k_base, ub_k
[DEBUG] Cube Submit: GEMM, ub_q, ub_k, ub_scores
[DEBUG] DMA Wait: LOAD, ub_k
[DEBUG] Cube Submit: SCALE, ub_scores, 0.125
[DEBUG] Cube Submit: REDUCE_MAX, ub_scores, ub_row_max
[DEBUG] DMA Submit: LOAD, v_base, ub_v
[DEBUG] Cube Submit: SOFTMAX_ACC, ub_scores, ub_v, ub_out
[DEBUG] DMA Wait: LOAD, ub_v
[DEBUG] DMA Submit: STORE, ub_out, out_base

这就是Runtime调度的流水线:K块DMA加载和Cube计算并行,V块DMA加载和softmax累加并行。如果DMA和Cube没有交替出现,说明流水线没有化开,吞吐会低。

实战:定位一个Tiling导致的OOM

写一个故意触发Tiling OOM的例子,然后定位:

import torch
import torch_npu

# 故意用长序列+大dim,触发UB超容量
# seq=10000, dim=256: 单块需要 10000*256*5*2=25.6MB,远超256KB
batch, heads, seq, dim = 1, 8, 10000, 256
q = torch.randn(batch, heads, seq, dim, device='npu', dtype=torch.float16)
k = torch.randn(batch, heads, seq, dim, device='npu', dtype=torch.float16)
v = torch.randn(batch, heads, seq, dim, device='npu', dtype=torch.float16)

try:
 out = torch_npu.npu_flash_attention(q, k, v, head_num=heads,
 input_layout="BNSD",
 scale=1.0/(dim**0.5),
 keep_prob=1.0)
except Exception as e:
 print(f"Error: {e}")

报错:RuntimeError: Tiling calculation failed: UB overflow

定位步骤:

1. 先算一下理论的分块需求

# 5个buffer × FP16(2字节) × block_seq × dim
# 超过256KB时,单块无法容纳
# 求解 block_seq * dim * 10 < 256 * 1024
# dim=256 → block_seq < 102 (超过102就OOM)

# 手动算一下
dim = 256
ub_kb = 256
required_per_element = 5 * 2 # 5 buffers × FP16
max_elements_per_block = ub_kb * 1024 / required_per_element
max_seq = max_elements_per_block / dim
print(f"dim={dim}: 最大分块seq={max_seq:.0f},超过这个值会OOM")

2. 手动padding到Runtime能接受的范围

# 先padding到对齐长度
def pad_to_align(x, align=16):
 return ((x + align - 1) // align) 
...(truncated)...
Logo

作为“人工智能6S店”的官方数字引擎,为AI开发者与企业提供一个覆盖软硬件全栈、一站式门户。

更多推荐