为什么矩阵乘需要模板?

大多数人听到这个问题,第一反应是:矩阵乘不就是两层循环吗?C = A × B,有什么好模板的?

这个回答在 CPU 上勉强成立。在昇腾NPU 上,要让 Cube 单元跑满性能,你需要手写 200~300 行 Ascend C 代码,精细控制分块策略、流水编排、DMA 搬运和片上缓存分配。一个矩阵乘,写一个星期算快的。

catlass 就是为了解决这个问题——不是帮你写代码,是给你一套可组装的预制件,拼出来就是高性能矩阵乘。

但这里有个认知需要纠偏:catlass 不是"另一个 CUTLASS"。 它的设计目标、分层抽象、硬件特化方向,全是为昇腾达芬奇架构量身定制的。拿 CUTLASS 的思路去理解 catlass,会完全走偏。

🧩 catlass 的三层抽象:不是"封装",是"可组装"

catlass 的核心设计理念是分层抽象 + 白盒化组装

三层结构:

TLA(逻辑层)→ MLA(映射层)→ PLA(物理层)

TLA(Template Logic Layer):定义矩阵分块策略。Tile Size 多大?Pipeline 分几个 stage?分块是按 M 维度还是 N 维度优先?这一层只管逻辑,不管硬件怎么执行。

MLA(Mapping Logic Layer):把 TLA 的分块策略映射到 NPU 的 Cube 单元和 Vector 单元。一个 Tile 分给 Cube 算还是 Vector 算?Cube 算完怎么触发 Vector 接手?这一层管调度。

PLA(Physical Logic Layer):生成具体指令。DMA 搬运指令、MTE 计算指令、同步指令,全在这一层发出。

// catlass 三层模板调用示例(简化)
// 为什么分三层?因为每一层可以独立替换,不用改其他层
auto gemm_template = catlass::GemmTemplate()
 .SetTLA<TLA_M64_N64_K32_Pipeline2>() // 逻辑层:分块64x64,2级流水
 .SetMLA<MLA_CubeFirst_VectorFollow>() // 映射层:Cube先算,Vector跟算
 .SetPLA<PLA_DMA_Ascend910_Pipeline2>(); // 物理层:针对Ascend 910的DMA指令

// 生成算子
auto gemm_op = gemm_template.Instantiate();

关键特征:每一层都是可替换的。TLA 换一个分块策略,MLA 和 PLA 不用改。MLA 换一个调度策略,TLA 和 PLA 不用改。这不是简单的封装,是白盒化组装——你能看到每一层在做什么,也能单独换掉某一层。

🔬 矩阵乘模板的分层拆解

TLA 层:分块策略是性能的命门

矩阵乘的性能瓶颈不是计算,是数据搬运。分块大小直接决定了片上缓存的命中率。

catlass 的 TLA 层内置了多种分块策略:

Tile Size 适用场景 Cube 利用率
M64_N64_K32 小 batch 推理 78%
M128_N128_K64 中等 batch 训练 92%
M256_N128_K128 大 batch 训练 95%

为什么不是越大越好?因为片上缓存(L1 Buffer)是有限的。Tile 太大会导致分块内数据放不下片上缓存,频繁回写显存,性能反而下降。

TLA 层还管 Pipeline Stage 数量。2 级流水意味着 Cube 在算第 N 个分块的时候,DMA 在搬第 N+1 个分块的数据。计算和搬运重叠,这是猫lass 性能接近手写算子的核心原因。

MLA 层:Cube 和 Vector 的协作模式

昇腾NPU 的达芬奇架构有两个计算单元:Cube(矩阵乘)和 Vector(向量运算)。矩阵乘的乘加操作归 Cube,Bias 加法和激活函数归 Vector。

MLA 层定义这两个单元的协作模式:

// MLA 层:Cube 先算,Vector 跟算(常用模式)
template<>
class MLA_CubeFirst_VectorFollow {
 __aicore__ void Compute(int32_t block_idx) {
 // Stage 1: Cube 算矩阵乘分块
 LocalTensor c_block = CubeGemm(a_block, b_block);
 // 为什么Cube算完不直接写显存?因为Bias还没加,写了还要再读
 // 留在片上,Vector直接接手
 
 // Stage 2: Vector 算 Bias 加法(接 Cube 的片上结果)
 LocalTensor bias_added = VectorAddBias(c_block, bias_block);
 // Stage 3: 可选激活函数
 LocalTensor output = VectorRelu(bias_added);
 
 // Stage 4: 只有最终结果写回显存
 QEMU::SetTensor(output_gm, output, block_idx);
 }
};

认知纠偏:很多人以为 Cube 和 Vector 是串行执行的。实际上在 catlass 的 MLA 层,Cube 算第 N 个分块的时候,Vector 在处理第 N-1 个分块。流水重叠,不是串行等待。

PLA 层:针对具体硬件生成指令

PLA 层是和硬件打交道的那一层的。Ascend 910、950PR、950DT 的 DMA 指令和 Cube 指令有差异,PLA 层做了硬件特化。

// PLA 层:针对 Ascend 910 的 DMA 搬运指令
template<>
class PLA_DMA_Ascend910 {
 void LoadABlock(int32_t block_idx) {
 // 为什么用 DMA 而不是普通 load?
 // DMA 是异步的,发出去就可以去干别的,不用等数据搬完
 DMA::LoadAsync(a_block, a_gm, block_idx * TILE_M, TILE_M * K);
 // 双缓冲:当前分块用 Buffer0,下一分块预取到 Buffer1
 DMA::LoadAsync(a_block_buf1, a_gm, (block_idx+1) * TILE_M, TILE_M * K);
 }
};

950PR 和 950DT 的 DMA 带宽不一样,PLA 层会自动选择最优的搬运策略。你不用手动改代码,换一个 PLA 实现就行。

🛠️ 实战:用 catlass 写一个带 Bias 的矩阵乘

场景:实现一个 GEMM + Bias 的融合算子。

手写 Ascend C 版本(简化,实际要 250~300 行):

// 手写 Ascend C:要管分块、流水、DMA、同步……全部手写
class GemmBiasAscendC {
 __aicore__ void Init(...) { /* 30行:参数解析、缓存分配 */ }
 __aicore__ void Process(...) {
 for (int m = 0; m < M; m += TILE_M) {
 for (int n = 0; n < N; n += TILE_N) {
 // 1. DMA 搬运(20行)
 DMA::Load(a_block, ...);
 DMA::Load(b_block, ...);
 // 2. 等待搬运完成(同步,10行)
 SyncAll();
 // 3. Cube 算矩阵乘(15行)
 CubeGemm(...);
 // 4. Vector 加 Bias(10行)
 VectorAddBias(...);
 // 5. 写回显存(10行)
 DMA::Store(...);
 }
 }
 }
};

catlass 模板版本(50 行):

// catlass 模板:分块、流水、DMA 全在模板里管了,你只管拼装
#include "catlass/gemm/gemm_template.h"
#include "catlass/epilogue/bias_add.h"

using GemmBiasTemplate = catlass::GemmTemplate<
 catlass::tla::Tile_M128_N128_K64_P2, // TLA:分块128x128,2级流水
 catlass::mla::CubeFirst_VectorFollow, // MLA:Cube先算,Vector跟算
 catlass::pla::DMA_Ascend910_Pipeline2 // PLA:Ascend 910的DMA指令
>;

using GemmBiasOp = catlass::EpilogueWrapper<
 GemmBiasTemplate,
 catlass::epilogue::BiasAdd // 融合Bias加法
>;

// 实例化并调用(核心代码只有5行)
extern "C" void gemm_bias_kernel(float* C, const float* A, const float* B, const float* bias, int M, int N, int K) {
 GemmBiasOp op;
 op.Init(A, B, bias, M, N, K);
 op.Run(); // 分块、流水、DMA全在模板里管完了
}

代码量对比:300 行 vs 50 行,性能达到手写算子的 95%。

剩下 5% 的差距在极端场景——比如 M=1 的 decode 阶段矩阵乘,模板的通用分块策略不是最优。但这种场景占不到 1%,99% 的场景 catlass 模板就是最优解

“模板不是银弹,但是99%场景下的最优解。”

⚡ catlass vs CUTLASS:关键差异

CUTLASS 是为 NVIDIA GPU 的 Tensor Core 设计的。catlass 是为昇腾NPU 的 Cube 单元设计的。两者在抽象层次上相似,但硬件映射完全不同。

对比维度 CUTLASS (NVIDIA) catlass (昇腾)
目标硬件 Tensor Core Cube 单元 + Vector 单元
编程模型 CUDA warp-level Ascend C block-level
存储层次 Shared Memory 片上 L1 Buffer
流水编排 Producer-Consumer warp Pipeline Stage(硬件支持)
硬件特化 按 Compute Capability 按 Ascend 910/950PR/950DT

最关键的区别:CUTLASS 的流水编排依赖软件层面的 warp 调度,catlass 的流水编排是达芬奇架构硬件支持的(Pipeline Stage 是硬件概念,不是软件模拟)。这意味着 catlass 的流水效率更高,但也更绑定硬件。

“写算子就像盖房子,catlass 给你的是预制件,不是砖头。”

用砖头盖房子(手写 Ascend C)灵活度最高,但慢。用预制件盖房子(catlass 模板)速度快,质量也有保障,适合 99% 的场景。

结尾

catlass 是昇腾CANN 算子体系里的模板库,和 CUTLASS 定位类似但硬件特化方向完全不同。如果你在做算子开发,catlass 能帮你省掉 80% 的底层编码工作。

仓库地址:https://atomgit.com/cann/catlass

opbase 基础依赖:https://atomgit.com/cann/opbase

Logo

作为“人工智能6S店”的官方数字引擎,为AI开发者与企业提供一个覆盖软硬件全栈、一站式门户。

更多推荐