昇腾CANN的算子“零件厂“:catlass仓库到底在生产什么
如果把昇腾NPU上的大模型算子比作一辆汽车,FlashAttention是发动机,RMSNorm是刹车片,RoPE是方向盘——那catlass是什么?是生产这些零件的模具和机床。第一次接触昇腾CANN生态的时候,很容易忽略catlass。它不像ops-transformer那样直接提供FlashAttention这种"成品算子",也不像torch_npu那样可以直接调用。catlass藏在更底层,
如果把昇腾NPU上的大模型算子比作一辆汽车,FlashAttention是发动机,RMSNorm是刹车片,RoPE是方向盘——那catlass是什么?是生产这些零件的模具和机床。
第一次接触昇腾CANN生态的时候,很容易忽略catlass。它不像ops-transformer那样直接提供FlashAttention这种"成品算子",也不像torch_npu那样可以直接调用。catlass藏在更底层,提供的是算子开发的基础模板——分块矩阵乘、softmax归一化、数据搬运策略。FlashAttention里的在线softmax和分块计算,底层都靠catlass的模板撑着。
catlass不是CUTLASS
这个名字容易让人误解。NVIDIA生态里有个著名的CUTLASS(CUDA Templates for Linear Algebra Subroutines),是英伟达的矩阵运算模板库。catlass的命名确实有致敬CUTLASS的意味,但它跟CUTLASS没有代码层面的关系,也不是CUTLASS的昇腾移植版。
catlass是昇腾NPU原生的算子模板库,专门针对达芬奇架构的硬件特性设计:Unified Buffer的容量约束、Cube和Vector计算单元的流水编排、DMA搬运和计算的重叠。这些在CUDA上是另一套逻辑,没法直接搬过来。
仓库定位:算子开发的"基础设施"
昇腾CANN的算子生态分三层,catlass在最底层:
应用算子层:ops-transformer(FlashAttention、RMSNorm等成品算子)
↑ 依赖
算子模板层:catlass(分块矩阵乘、reduce、softmax模板)
↑ 依赖
硬件抽象层:Ascend C(达芬奇架构的编程接口)
打个比方:catlass提供砖头和水泥,ops-transformer用这些材料盖房子。想在昇腾NPU上开发新的FlashAttention变体?先得理解catlass怎么生产"砖头"。
仓库结构:模板在哪里
git clone https://atomgit.com/cann/catlass.git
cd catlass
核心目录:
catlass/
├── include/
│ └── catlass/
│ ├── gemm/ # 分块矩阵乘模板(核心)
│ │ ├── kernel/
│ │ │ ├── gemm_split_k.h # Split-K并行策略
│ │ │ └── gemm_batched.h # 批量矩阵乘
│ │ ├── thread/
│ │ │ └── mma.h # 矩阵乘累加
│ │ └── collective/
│ │ └── sm80_gemm.h # 分块搬运+计算流水
│ ├── reduction/ # 归约操作模板
│ │ ├── thread/
│ │ │ └── reduce.h # 单块内reduce
│ │ └── block/
│ │ └── reduce.h # 跨块reduce
│ ├── softmax/ # softmax模板
│ │ └── online_softmax.h # 在线softmax(FlashAttention的基石)
│ └── layout/ # 数据布局适配
│ └── layout.h # BNSD/BSND等格式转换
├── examples/ # 使用示例
│ ├── flash_attention/ # 用catlass搭FlashAttention
│ └── simple_gemm/ # 基础矩阵乘示例
└── tests/ # 单元测试
online_softmax.h和gemm/是FlashAttention开发者最需要关注的两个模块。
分块矩阵乘:catlass的发动机
FlashAttention的核心操作是Q×K^T和注意力权重×V,本质上都是分块矩阵乘。catlass的GEMM模板把分块策略、数据搬运、计算流水全部封装好:
// catlass/include/catlass/gemm/kernel/gemm_split_k.h
// 分块矩阵乘的简化示意(伪代码)
template<typename ElementA, typename ElementB, typename ElementC>
struct GemmSplitK {
// 分块参数:决定一次算多大的块
struct TileShape {
static const int kM = 128; // M方向分块大小
static const int kN = 128; // N方向分块大小
static const int kK = 64; // K方向分块大小
};
// 整个GEMM分三阶段流水
void operator()(Params const& params) {
// 阶段1:从GMEM加载A/B分块到L1/UB
LoadTileFromGlobal(params.A, params.B, tile_a, tile_b);
// 阶段2:Cube单元做矩阵乘累加
// 阶段1和阶段2流水化:加载第(i+1)块的同时计算第i块
for (int k = 0; k < params.K / TileShape::kK; k++) {
LoadNextTile(...); // DMA搬运
ComputeCurrentTile(tile_a, tile_b, acc); // Cube计算
}
// 阶段3:结果从UB写回GMEM
StoreTileToGlobal(params.C, acc);
}
};
这里最关键的设计是搬运和计算的流水化。昇腾NPU有独立的DMA引擎和计算单元,加载下一块数据的同时可以计算当前块,两者并行。如果等加载完再算、算完再加载,吞吐直接砍半。
分块大小128×128×64不是随便定的。昇腾910的Cube单元一次能处理16×16的FP16矩阵乘,128×128正好是8×8个Cube微操作,UB装得下,且对齐到128字节边界。
在线softmax:FlashAttention的数学基石
catlass里的online_softmax.h是在线softmax的模板实现,FlashAttention的前向计算直接依赖它:
// catlass/include/catlass/softmax/online_softmax.h
// 在线softmax模板(伪代码,示意流程)
template<typename T, int BlockSize>
struct OnlineSoftmax {
struct State {
T row_max; // 当前行最大值(防溢出用)
T row_sum; // 当前行指数和
T* acc_output; // 累加输出指针
};
// 每处理一个新块就调用一次Update
static void Update(State& state, T* new_scores, T* new_values, int block_len) {
// 找新块的局部最大值
T local_max = ReduceMax(new_scores, block_len);
// 更新全局最大值
T new_max = Max(state.row_max, local_max);
// 关键:重新缩放之前的累加结果
// 数学上等价于把所有分数放到同一个exp尺度下
T correction = Exp(state.row_max - new_max);
state.row_sum = state.row_sum * correction;
ScaleAccOutput(state.acc_output, correction);
// 加上新块的贡献
// 减new_max防溢出:FP16下exp(>88.7)=inf
T* exp_scores = ExpSub(new_scores, new_max, block_len);
T local_sum = ReduceSum(exp_scores, block_len);
state.row_sum = state.row_sum + local_sum;
AccumulateOutput(state.acc_output, exp_scores, new_values);
state.row_max = new_max;
}
// 所有块处理完后归一化
static void Finalize(State& state, int seq_len) {
ScaleAccOutput(state.acc_output, 1.0 / state.row_sum);
}
};
correction = Exp(state.row_max - new_max)这一行是在线softmax的精髓。标准softmax需要先扫一遍全局最大值,再扫一遍算指数和。在线softmax把它变成增量的:每来一个新块,如果发现了更大的值,就把之前所有结果按比例缩放回来。数学上完全等价,但不需要存全局信息。
FlashAttention的前向计算就是:外层循环遍历Q分块,内层循环遍历K/V分块,每一步调catlass的OnlineSoftmax::Update更新累加结果。
catlass跟ops-transformer的协作
实际代码里,ops-transformer的FlashAttention算子直接引用catlass的头文件:
ops-transformer/opkernel/flash_attention/
├── flash_attention_score.cc
│ #include "catlass/gemm/kernel/gemm_split_k.h" // 分块矩阵乘
│ #include "catlass/softmax/online_softmax.h" // 在线softmax
│ #include "catlass/reduction/block/reduce.h" // 跨块reduce
开发流程是这样的:
- catlass定义
OnlineSoftmax、GemmSplitK等模板 - ops-transformer在FlashAttention算子里实例化这些模板,传入昇腾NPU的硬件参数(UB大小、Cube规格等)
- 编译时模板展开,生成针对910硬件优化的机器码
这种分层设计的好处:改FlashAttention的算法逻辑(比如加个causal mask),只改ops-transformer;改分块策略或softmax的数值精度,改catlass。互不干扰。
从catlass搭建一个简化版FlashAttention
catlass的examples目录里有一个简化版FlashAttention示例,展示了怎么用模板搭出完整算子:
// catlass/examples/flash_attention/flash_attention_example.cc
// 简化示意
template<typename T>
void FlashAttentionForward(
T* query, T* key, T* value, T* output,
int batch, int heads, int seq_len, int dim,
int block_size = 128
) {
int num_blocks = seq_len / block_size;
for (int b = 0; b < batch; b++) {
for (int h = 0; h < heads; h++) {
for (int qi = 0; qi < num_blocks; qi++) {
// 取Q分块
T* q_block = query + offset(b, h, qi * block_size, 0);
// 初始化在线softmax状态
typename OnlineSoftmax<T, 128>::State state;
state.row_max = -1e9;
state.row_sum = 0;
for (int ki = 0; ki < num_blocks; ki++) {
T* k_block = key + offset(b, h, ki * block_size, 0);
T* v_block = value + offset(b, h, ki * block_size, 0);
// 用catlass的GEMM模板算Q×K^T
GemmSplitK<T, T, T> gemm;
T local_scores[128 * 128];
gemm(q_block, k_block, local_scores, block_size, block_size, dim);
// 缩放
Scale(local_scores, 1.0 / sqrt(dim));
// 用catlass的在线softmax更新
OnlineSoftmax<T, 128>::Update(state, local_scores, v_block, block_size);
}
// 归一化输出
OnlineSoftmax<T, 128>::Finalize(state, seq_len);
// 写回output
StoreOutput(output + offset(b, h, qi * block_size, 0), state);
}
}
}
}
这就是FlashAttention的骨架:分块矩阵乘+在线softmax。catlass把底层的搬运、对齐、流水化全部封装在模板里,上层代码只需要关心"算什么"和"怎么累加"。
跟其他仓库的边界
catlass的职责边界需要搞清楚:
| 需求 | 去哪个仓库 |
|---|---|
| 调FlashAttention跑推理 | torch_npu |
| 改FlashAttention的算法逻辑 | ops-transformer |
| 改分块策略或softmax模板 | catlass |
| 写一个全新的算子 | catlass(模板)+ ops-*(注册) |
| 调图融合优化 | ge |
catlass只负责"怎么高效地在昇腾NPU上做基础运算",不涉及具体的算子逻辑和图调度。
##昇腾NPU上catlass的硬件适配
catlass的模板设计有三个昇腾特有的约束:
UB容量约束:Unified Buffer约256KB,分块大小必须适配。catlass的TileShape默认128×128×64,单个FP16分块约占32KB,5个分块(Q/K/V/Score/Output)加中间结果约160KB,留有余量
128字节对齐:昇腾NPU的DMA搬运要求数据起始地址128字节对齐。catlass的layout模块自动处理padding和对齐,开发者不需要手动算偏移
Cube和Vector的分工:矩阵乘走Cube单元(高吞吐),标量运算和softmax走Vector单元。catlass的GEMM模板默认走Cube,OnlineSoftmax走Vector,两者通过UB传递数据
克隆catlass仓库,先看examples/flash_attention/的示例代码,理解模板怎么搭出完整算子。然后读online_softmax.h的在线softmax实现,这是FlashAttention的数学基石。
更多推荐




所有评论(0)