【昇腾CANN】catlass 算子模板库深度解读:从 CUTLASS 到昇腾的 GEMM 优化之路
【昇腾CANN】catlass 算子模板库深度解读:从 CUTLASS 到昇腾的 GEMM 优化之路
### 【昇腾CANN】catlass 算子模板库深度解读:从 CUTLASS 到昇腾的 GEMM 优化之路
上周五下午,同事老王在昇腾 NPU 上跑一个推荐模型,矩阵乘法(GEMM)占了 73% 的耗时。他翻遍了昇腾官方文档,试了各种配置,延迟还是下不来。我看了一眼他的代码——他直接调的是 AscendCL 的 aclblasGemmEx,参数全用默认值。我跟他说:“你这不是在调优,你这是在摸彩票。”
问题在于,GEMM 是一个水很深的算子。表面上是 C = alpha * A * B + beta * C,但内存布局(RowMajor vs ColumnMajor)、数据精度(FP16 vs FP32 vs BF16)、分块大小(Tile Size)、向量化宽度(Vector Width),每一个参数组合都会让硬件利用率差出几倍。
catlass 就是干这个的——它是昇腾社区开源的算子模板库,基于 NVIDIA CUTLASS 2.x 的设计哲学,为昇腾 NPU 重新实现了一套 GEMM 优化模板。你不需要从头写汇编,只需要在模板里填几个参数,catlass 就能帮你生成接近硬件极限的矩阵乘法。
catlass 是什么:昇腾的算子模板库
先说清楚 catlass 和几个容易混淆的概念。catlass 不是 CUTLASS 的直接移植。CUTLASS 是 NVIDIA 的,它的代码是为 NVIDIA 的 Tensor Core 写的。catlass 是昇腾社区参考 CUTLASS 的设计思路,重新为昇腾达芬奇架构实现的版本。两者在概念上一致(都是分层 GEMM 模板),但代码完全不兼容。
简单来说,catlass 是昇腾 CANN 开源生态中的高性能矩阵乘模板库。它通过分层抽象、模块化设计与可组合模板,将 GEMM 的实现白盒化。开发者只需关注“做什么”,而非“怎么做”,使开发者无需成为底层专家,也能快速构建接近硬件理论峰值的定制化 GEMM 算子。
什么时候用 catlass?
- 你已经 profile 过了,发现 GEMM 是瓶颈。
- 你需要融合多个操作进一个算子(比如 GEMM + BiasAdd + ReLU 进一个 kernel)。
- 你的数据布局或精度组合 AscendCL 没支持。
什么时候不要用?
- AscendCL 的
aclblasGemmEx已经够用(大多数场景够用)。 - 你对性能没有极致要求。
- 你不熟悉 C++ 模板元编程。
核心概念:GEMM 的分层抽象
catlass 把 GEMM 的计算过程拆成了五层,每层抽象一个粒度的并行:
- Level-5: Host(参数配置 & 内存分配)
- Level-4: ThreadBlock(Tile调度)
- Level-3: Warp(向量化计算)
- Level-2: Thread(寄存器级操作)
- Level-1: TensorOp(指令级硬件加速)
Level-1 是最小的计算单元,对应昇腾 NPU 的向量指令。Level-2 是一个线程处理多个 TensorOp。Level-3 是 32 个线程为一组,协同执行更高效的向量化操作。Level-4 是一个线程块调度多个 warp,负责把 GEMM 的结果累加到输出矩阵的某个子块。Level-5 则是在 CPU 端配置参数、分配设备内存、发起核函数调用。
GEMM 分块:把大矩阵塞进 SRAM
GEMM 最核心的优化思路是分块(Tiling)。昇腾 NPU 有两层内存:HBM(显存,容量大但带宽相对低)和 SRAM(片上缓存,容量小但带宽极高)。标准 GEMM 把整个 A/B/C 矩阵放在 HBM 里,每次计算都要从 HBM 读 A 的一行和 B 的一列。如果矩阵很大,带宽直接打满,NPU 算力反而闲置。
catlass 的分块策略是:把 A 和 B 切成小块(Tile),每次只把一个小块从 HBM 读到 SRAM,算完就扔,不写回 HBM,直到这个 tile 对应的所有计算完成。
// gemm_kernel.h(catlass 模板核心逻辑)
template <
typename ElementA, // A矩阵元素类型
typename ElementB, // B矩阵元素类型
typename ElementC, // C矩阵元素类型(累加器)
typename ThreadBlockShape, // ThreadBlock 大小,如 128x128
typename WarpShape, // Warp 大小,如 64x64
typename InstructionShape, // TensorOp 大小,如 8x8
typename LayoutA,
typename LayoutB
>
__global__ void GemmKernel(
ElementA* A, ElementB* B, ElementC* C,
ElementC alpha, ElementC beta,
int M, int N, int K
) {
// 1. 声明 SRAM 缓冲区(编译期确定大小)
// __shared__ 关键字表示放在片上共享内存(SRAM)
__shared__ ElementA smemA[ThreadBlockShape::kM][ThreadBlockShape::kK];
__shared__ ElementB smemB[ThreadBlockShape::kK][ThreadBlockShape::kN];
// 2. 初始化累加器(用ElementC类型,累加器通常是float)
FragmentC accumulators;
accumulators.clear(); // 全零初始化
// 3. 外层循环:按K维度切块,逐步把K轴的数据读到SRAM
for (int k = 0; k < K; k += ThreadBlockShape::kK) {
// 3.1 从HBM异步拷贝 A 的 tile 到 SRAM
AsyncCopy(A + threadIdx.x * ..., smemA[threadIdx.y][threadIdx.x], copyParamsA);
// 3.2 从HBM异步拷贝 B 的 tile 到 SRAM
AsyncCopy(B + threadIdx.x * ..., smemB[threadIdx.x][threadIdx.y], copyParamsB);
// 3.3 等待拷贝完成(同步点)
__syncthreads();
// 3.4 在SRAM上执行矩阵乘法(分小块进 TensorOp)
#pragma unroll
for (int kInner = 0; kInner < WarpShape::kK; ++kInner) {
compute_mma(accumulators, smemA, smemB, kInner);
}
// 3.5 等待计算完成,准备下一块
__syncthreads();
}
// 4. 把累加结果从寄存器写回HBM
WriteC(accumulators, C + blockIdx.x * ThreadBlockShape::kM, ...);
}
代码解读(解释 WHY 而非 WHAT):
__shared__声明的数组在 SRAM 上,和 HBM 相比带宽差了 40 倍。不在这里反复读写,硬件根本跑不满。- 外层循环按 K 轴切块,每次只加载
ThreadBlockShape::kK列的数据——这是控制 SRAM 占用和 HBM 带宽的杠杆。 AsyncCopy是异步的,拷贝和计算可以流水线并行——当 Warp 1 在用当前 A/B tile 计算时,Warp 2 同时在拷贝下一块 A/B。#pragma unroll手动展开内层循环,让编译器做更好的指令调度,减少循环开销。
性能调参:Tile 大小怎么选?
catlass 调优最核心的问题是:ThreadBlockShape 和 WarpShape 怎么选?这是一个由硬件约束驱动的选择题,不是调出来的,是查出来的。
昇腾 Ascend 910 的 SRAM 容量是 192KB/AI Core。一个 ThreadBlock 需要的 SRAM 空间计算公式为:SRAM 占用 = ThreadBlockShape::kM × ThreadBlockShape::kK × sizeof(A元素) + ThreadBlockShape::kK × ThreadBlockShape::kN × sizeof(B元素) + 寄存器文件
如果选 ThreadBlockShape = 128×128,FP16 精度:smemA = 128 × 128 × 2 = 32KB,smemB = 128 × 128 × 2 = 32KB,合计 64KB,加上寄存器溢出,一个 AI Core 绰绰有余。但如果选 ThreadBlockShape = 256×256,合计 256KB,这已经超过 192KB 的 SRAM 上限,编译会直接报错。
实际可用的组合需要同时满足:
smemA + smemB < SRAM 容量 × 0.85(留 15% 给其他用途)WarpShape::kM × WarpShape::kN是ThreadBlockShape的因数(能被整除)InstructionShape是WarpShape的因数
一个保守的可工作配置(Ascend 910 FP16):
using ThreadBlockShape = cutlass::gemm::threadblock::GemmShape<128, 128, 64>;
using WarpShape = cutlass::gemm::warp::GemmShape<64, 64, 64>;
using InstructionShape = cutlass::gemm::warp::GemmShape<8, 8, 16>;
从模板到实战:写一个融合 GEMM+Bias+ReLU
catlass 真正的威力在算子融合。如果你有 GEMM + BiasAdd + ReLU 三个操作,用 AscendCL 要分三次调用,涉及三次 HBM 读写和两次额外的 kernel launch 开销。用 catlass 可以一次搞定:
// 方式2:catlass融合算子(高效)
template <typename GemmConfig, typename EpilogueOp>
__global__ void GemmBiasReLUKernel(
ElementA* A, ElementB* B, ElementC* C,
ElementBias* bias, // 额外的bias指针
int M, int N, int K
) {
// 1. 声明共享内存和累加器
__shared__ ElementA smemA[...];
__shared__ ElementB smemB[...];
FragmentC accumulators;
// 2. GEMM阶段(和之前一样)
for (int k = 0; k < K; k += kTile) {
AsyncCopy(...);
__syncthreads();
compute_mma(accumulators, smemA, smemB);
__syncthreads();
}
// 3. Epilogue阶段:在寄存器里做 Bias + ReLU
// 注意:这一步不需要额外的HBM读写!
ElementC bias_val = load(bias + blockIdx.x * N); // 只读一次bias
ElementC output = accumulators + bias_val; // 寄存器级加法
output = max(output, 0); // ReLU,寄存器级
WriteC(output, C + ...); // 直接写最终结果
}
为什么融合后快?因为 BiasAdd + ReLU 不走 HBM,直接在寄存器里完成,减少了 HBM 访问和额外的 kernel launch 开销。通过 catlass 的模板化设计,开发者只需简单配置几个关键参数,就能自动生成针对特定硬件优化的 GEMM 算子,将传统需要数周的算子开发周期缩短至 1-3 天。
上周五下午,同事老王在昇腾 NPU 上跑一个推荐模型,矩阵乘法(GEMM)占了 73% 的耗时。他翻遍了昇腾官方文档,试了各种配置,延迟还是下不来。我看了一眼他的代码——他直接调的是 AscendCL 的 aclblasGemmEx,参数全用默认值。我跟他说:“你这不是在调优,你这是在摸彩票。”
问题在于,GEMM 是一个水很深的算子。表面上是 C = alpha * A * B + beta * C,但内存布局(RowMajor vs ColumnMajor)、数据精度(FP16 vs FP32 vs BF16)、分块大小(Tile Size)、向量化宽度(Vector Width),每一个参数组合都会让硬件利用率差出几倍。
catlass 就是干这个的——它是昇腾社区开源的算子模板库,基于 NVIDIA CUTLASS 2.x 的设计哲学,为昇腾 NPU 重新实现了一套 GEMM 优化模板。你不需要从头写汇编,只需要在模板里填几个参数,catlass 就能帮你生成接近硬件极限的矩阵乘法。
catlass 是什么:昇腾的算子模板库
先说清楚 catlass 和几个容易混淆的概念。catlass 不是 CUTLASS 的直接移植。CUTLASS 是 NVIDIA 的,它的代码是为 NVIDIA 的 Tensor Core 写的。catlass 是昇腾社区参考 CUTLASS 的设计思路,重新为昇腾达芬奇架构实现的版本。两者在概念上一致(都是分层 GEMM 模板),但代码完全不兼容。
简单来说,catlass 是昇腾 CANN 开源生态中的高性能矩阵乘模板库。它通过分层抽象、模块化设计与可组合模板,将 GEMM 的实现白盒化。开发者只需关注“做什么”,而非“怎么做”,使开发者无需成为底层专家,也能快速构建接近硬件理论峰值的定制化 GEMM 算子。
什么时候用 catlass?
- 你已经 profile 过了,发现 GEMM 是瓶颈。
- 你需要融合多个操作进一个算子(比如 GEMM + BiasAdd + ReLU 进一个 kernel)。
- 你的数据布局或精度组合 AscendCL 没支持。
什么时候不要用?
- AscendCL 的
aclblasGemmEx已经够用(大多数场景够用)。 - 你对性能没有极致要求。
- 你不熟悉 C++ 模板元编程。
核心概念:GEMM 的分层抽象
catlass 把 GEMM 的计算过程拆成了五层,每层抽象一个粒度的并行:
- Level-5: Host(参数配置 & 内存分配)
- Level-4: ThreadBlock(Tile调度)
- Level-3: Warp(向量化计算)
- Level-2: Thread(寄存器级操作)
- Level-1: TensorOp(指令级硬件加速)
Level-1 是最小的计算单元,对应昇腾 NPU 的向量指令。Level-2 是一个线程处理多个 TensorOp。Level-3 是 32 个线程为一组,协同执行更高效的向量化操作。Level-4 是一个线程块调度多个 warp,负责把 GEMM 的结果累加到输出矩阵的某个子块。Level-5 则是在 CPU 端配置参数、分配设备内存、发起核函数调用。
GEMM 分块:把大矩阵塞进 SRAM
GEMM 最核心的优化思路是分块(Tiling)。昇腾 NPU 有两层内存:HBM(显存,容量大但带宽相对低)和 SRAM(片上缓存,容量小但带宽极高)。标准 GEMM 把整个 A/B/C 矩阵放在 HBM 里,每次计算都要从 HBM 读 A 的一行和 B 的一列。如果矩阵很大,带宽直接打满,NPU 算力反而闲置。
catlass 的分块策略是:把 A 和 B 切成小块(Tile),每次只把一个小块从 HBM 读到 SRAM,算完就扔,不写回 HBM,直到这个 tile 对应的所有计算完成。
// gemm_kernel.h(catlass 模板核心逻辑)
template <
typename ElementA, // A矩阵元素类型
typename ElementB, // B矩阵元素类型
typename ElementC, // C矩阵元素类型(累加器)
typename ThreadBlockShape, // ThreadBlock 大小,如 128x128
typename WarpShape, // Warp 大小,如 64x64
typename InstructionShape, // TensorOp 大小,如 8x8
typename LayoutA,
typename LayoutB
>
__global__ void GemmKernel(
ElementA* A, ElementB* B, ElementC* C,
ElementC alpha, ElementC beta,
int M, int N, int K
) {
// 1. 声明 SRAM 缓冲区(编译期确定大小)
// __shared__ 关键字表示放在片上共享内存(SRAM)
__shared__ ElementA smemA[ThreadBlockShape::kM][ThreadBlockShape::kK];
__shared__ ElementB smemB[ThreadBlockShape::kK][ThreadBlockShape::kN];
// 2. 初始化累加器(用ElementC类型,累加器通常是float)
FragmentC accumulators;
accumulators.clear(); // 全零初始化
// 3. 外层循环:按K维度切块,逐步把K轴的数据读到SRAM
for (int k = 0; k < K; k += ThreadBlockShape::kK) {
// 3.1 从HBM异步拷贝 A 的 tile 到 SRAM
AsyncCopy(A + threadIdx.x * ..., smemA[threadIdx.y][threadIdx.x], copyParamsA);
// 3.2 从HBM异步拷贝 B 的 tile 到 SRAM
AsyncCopy(B + threadIdx.x * ..., smemB[threadIdx.x][threadIdx.y], copyParamsB);
// 3.3 等待拷贝完成(同步点)
__syncthreads();
// 3.4 在SRAM上执行矩阵乘法(分小块进 TensorOp)
#pragma unroll
for (int kInner = 0; kInner < WarpShape::kK; ++kInner) {
compute_mma(accumulators, smemA, smemB, kInner);
}
// 3.5 等待计算完成,准备下一块
__syncthreads();
}
// 4. 把累加结果从寄存器写回HBM
WriteC(accumulators, C + blockIdx.x * ThreadBlockShape::kM, ...);
}
代码解读(解释 WHY 而非 WHAT):
__shared__声明的数组在 SRAM 上,和 HBM 相比带宽差了 40 倍。不在这里反复读写,硬件根本跑不满。- 外层循环按 K 轴切块,每次只加载
ThreadBlockShape::kK列的数据——这是控制 SRAM 占用和 HBM 带宽的杠杆。 AsyncCopy是异步的,拷贝和计算可以流水线并行——当 Warp 1 在用当前 A/B tile 计算时,Warp 2 同时在拷贝下一块 A/B。#pragma unroll手动展开内层循环,让编译器做更好的指令调度,减少循环开销。
性能调参:Tile 大小怎么选?
catlass 调优最核心的问题是:ThreadBlockShape 和 WarpShape 怎么选?这是一个由硬件约束驱动的选择题,不是调出来的,是查出来的。
昇腾 Ascend 910 的 SRAM 容量是 192KB/AI Core。一个 ThreadBlock 需要的 SRAM 空间计算公式为:SRAM 占用 = ThreadBlockShape::kM × ThreadBlockShape::kK × sizeof(A元素) + ThreadBlockShape::kK × ThreadBlockShape::kN × sizeof(B元素) + 寄存器文件
如果选 ThreadBlockShape = 128×128,FP16 精度:smemA = 128 × 128 × 2 = 32KB,smemB = 128 × 128 × 2 = 32KB,合计 64KB,加上寄存器溢出,一个 AI Core 绰绰有余。但如果选 ThreadBlockShape = 256×256,合计 256KB,这已经超过 192KB 的 SRAM 上限,编译会直接报错。
实际可用的组合需要同时满足:
smemA + smemB < SRAM 容量 × 0.85(留 15% 给其他用途)WarpShape::kM × WarpShape::kN是ThreadBlockShape的因数(能被整除)InstructionShape是WarpShape的因数
一个保守的可工作配置(Ascend 910 FP16):
using ThreadBlockShape = cutlass::gemm::threadblock::GemmShape<128, 128, 64>;
using WarpShape = cutlass::gemm::warp::GemmShape<64, 64, 64>;
using InstructionShape = cutlass::gemm::warp::GemmShape<8, 8, 16>;
从模板到实战:写一个融合 GEMM+Bias+ReLU
catlass 真正的威力在算子融合。如果你有 GEMM + BiasAdd + ReLU 三个操作,用 AscendCL 要分三次调用,涉及三次 HBM 读写和两次额外的 kernel launch 开销。用 catlass 可以一次搞定:
// 方式2:catlass融合算子(高效)
template <typename GemmConfig, typename EpilogueOp>
__global__ void GemmBiasReLUKernel(
ElementA* A, ElementB* B, ElementC* C,
ElementBias* bias, // 额外的bias指针
int M, int N, int K
) {
// 1. 声明共享内存和累加器
__shared__ ElementA smemA[...];
__shared__ ElementB smemB[...];
FragmentC accumulators;
// 2. GEMM阶段(和之前一样)
for (int k = 0; k < K; k += kTile) {
AsyncCopy(...);
__syncthreads();
compute_mma(accumulators, smemA, smemB);
__syncthreads();
}
// 3. Epilogue阶段:在寄存器里做 Bias + ReLU
// 注意:这一步不需要额外的HBM读写!
ElementC bias_val = load(bias + blockIdx.x * N); // 只读一次bias
ElementC output = accumulators + bias_val; // 寄存器级加法
output = max(output, 0); // ReLU,寄存器级
WriteC(output, C + ...); // 直接写最终结果
}
为什么融合后快?因为 BiasAdd + ReLU 不走 HBM,直接在寄存器里完成,减少了 HBM 访问和额外的 kernel launch 开销。通过 catlass 的模板化设计,开发者只需简单配置几个关键参数,就能自动生成针对特定硬件优化的 GEMM 算子,将传统需要数周的算子开发周期缩短至 1-3 天。
更多推荐



所有评论(0)