CUTLASS 是 NVIDIA 的 GEMM 模板库,用 C++ 模板元编程生成高性能的矩阵乘 kernel。catlass 名字里也有 “c” + “atlass”,很容易让人以为它是昇腾版的 CUTLASS。

不是。catlass 和 CUTLASS 的定位不一样。

CUTLASS 是通用的 GEMM 模板库,目标是生成各种形状的矩阵乘 kernel。catlass 的目标更具体:为昇腾 NPU 的 Cube 单元生成最优的矩阵乘、卷积、Transformer 相关算子的 kernel 模板。 它生成的不是某一个算子,而是一族算子的实现模板。

这个区分很重要。理解了这一点,你才能看懂 catlass 的代码结构,以及它和 ops-transformer 的关系——ops-transformer 里的 FlashAttention kernel,底层的 Q×K^T 和 P×V 矩阵乘,可以用 catlass 生成的 GEMM kernel 来加速。

catlass 的核心设计:模板元编程 + 硬件参数化

catlass 的代码大量使用 C++ 模板元编程。第一次看可能会觉得代码很绕,但设计目标很明确:把"调优"这件事从手写 kernel 变成配置问题。

手写一个高性能的 GEMM kernel,需要考虑的参数有:

  • Tile 大小(M/N/K 三个维度)
  • 流水线级数(double buffering 还是 triple buffering)
  • 数据布局(RowMajor 还是 ColMajor)
  • 数据类型(fp16、bf16、fp32)
  • 硬件参数(Cube 单元的 M/N/K 维度、UB 大小、L1 Buffer 大小)

这些参数的组合空间非常大。手写一个最优配置需要大量实验。catlass 的思路是:把这些参数全部模板化,让编译器在编译期枚举所有可能的配置,然后通过 benchmark 选出最快的那个。

// catlass 的 GEMM 模板(简化版)
// 文件:catlass/include/gemm/gemm_template.h

template <
 typename ElementA, // A 的数据类型
 typename ElementB, // B 的数据类型
 typename ElementC, // C 的数据类型
 int TileM, // Tile 的 M 维度
 int TileN, // Tile 的 N 维度
 int TileK, // Tile 的 K 维度
 int Stages, // 流水线级数
 typename LayoutA, // A 的布局
 typename LayoutB, // B 的布局
 typename ArchTag // 硬件架构标签(Ascend910 / Ascend910B / ...)
>
class GemmTemplate {
public:
 // 核心计算:C = A × B + C
 void operator()(ElementA* A, ElementB* B, ElementC* C,
 int M, int N, int K, cudaStream_t stream);
};

为什么用模板而不是运行时参数? 因为 Tile 大小、Stages、Layout 这些参数在编译期确定后,编译器可以做大量的优化:

  1. 循环展开:TileM/TileN/TileK 是编译期常量,内层循环可以完全展开
  2. 寄存器分配:编译器知道需要多少寄存器,可以做最优分配
  3. 指令调度:编译器知道 Cube 单元的延迟,可以插入合适的 pipeline 指令

如果把这些参数改成运行时传参,这些优化全部消失,性能会掉 30-50%。

catlass 的自动调优:枚举 + benchmark

catlass 提供了一个自动调优工具,核心思路是 枚举所有合法的模板参数组合,然后跑 benchmark 选最快的

# catlass 的自动调优脚本(概念性)
# 文件:catlass/tools/auto_tune.py

import itertools

# 定义搜索空间
tile_m_options = [64, 128, 256]
tile_n_options = [64, 128, 256]
tile_k_options = [16, 32, 64]
stages_options = [2, 3]
dtype_options = [torch.fp16, torch.bf16]

# 枚举所有组合
configs = list(itertools.product(
 tile_m_options, tile_n_options, tile_k_options,
 stages_options, dtype_options
))

# benchmark 每个配置
best_time = float('inf')
best_config = None

for config in configs:
 tile_m, tile_n, tile_k, stages, dtype = config
 
 # 实例化模板(编译)
 gemm_kernel = GemmTemplate(
 ElementA=dtype, ElementB=dtype, ElementC=dtype,
 TileM=tile_m, TileN=tile_n, TileK=tile_k,
 Stages=stages, LayoutA=RowMajor, LayoutB=ColMajor,
 ArchTag=Ascend910
 )
 
 # benchmark
 time_ms = benchmark(gemm_kernel, M=4096, N=4096, K=4096)
 
 if time_ms < best_time:
 best_time = time_ms
 best_config = config

print(f"Best config: {best_config}, time: {best_time} ms")

这个脚本跑一次可能需要几个小时(因为要编译几百个模板实例)。但跑完之后,你会得到一个最优配置表,以后直接用这个配置就行,不需要重新调优。

catlass 的仓库里已经预置了常见形状的最优配置(比如 Llama 3 70B 的 Q×K^T 形状是 [batch, num_q_heads, seq_len, head_dim]),直接用就行。

catlass 和 FlashAttention 的关系

FlashAttention 的核心计算是两个矩阵乘:

  1. Q×K^T:计算注意力分数([batch, num_heads, seq_len, seq_len])
  2. P×V:加权求和([batch, num_heads, seq_len, head_dim])

这两个矩阵乘的形状比较特殊:

  • Q×K^T 的 K 维度是 head_dim(通常 128),比较小
  • P×V 的 N 维度是 head_dim(128),也比较小
  • 但 M 维度是 seq_len × num_heads,可以很大(比如 8192 × 32 = 262144)

标准的 GEMM 模板可能不适合这种形状。 catlass 提供了一个专门的 BatchGEMM 模板,处理 batch 维度在外的矩阵乘(FlashAttention 的 batch 维度是 batch × num_heads)。

// catlass 的 BatchGEMM 模板(简化版)
// 文件:catlass/include/gemm/batch_gemm_template.h

template <
 typename ElementA,
 typename ElementB,
 typename ElementC,
 int Batch, // batch 大小
 int TileM,
 int TileN,
 int TileK,
 int Stages,
 typename LayoutA,
 typename LayoutB,
 typename ArchTag
>
class BatchGemmTemplate {
public:
 // 核心计算:C[b] = A[b] × B[b] + C[b], for b in [0, Batch)
 void operator()(ElementA* A, ElementB* B, ElementC* C,
 int M, int N, int K, cudaStream_t stream);
};

BatchGEMM 和 GEMM 的区别:GEMM 是把 batch 维度展开成 M 维度([Batch×M, N] × [N, K] = [Batch×M, K]),BatchGEMM 是真正的 batch 矩阵乘(每个 batch 独立计算)。

FlashAttention 用的是 BatchGEMM(因为每个 head 的注意力分数是独立的)。

catlass 的卷积模板

除了 GEMM,catlass 还提供了卷积模板。卷积可以展开成矩阵乘(im2col),然后调用 GEMM 模板。

// catlass 的卷积模板(简化版)
// 文件:catlass/include/conv/conv_template.h

template <
 typename ElementA,
 typename ElementB,
 typename ElementC,
 int TileM,
 int TileN,
 int TileK,
 int Stages,
 typename LayoutA,
 typename LayoutB,
 typename ArchTag
>
class ConvTemplate {
public:
 // 核心计算:Y = conv(X, W)
 // 内部实现:im2col(X) → X_col, reshape(W) → W_col, then GEMM
 void operator()(ElementA* X, ElementB* W, ElementC* Y,
 int N, int C, int H, int W, int R, int S,
 cudaStream_t stream);
};

为什么不用直接的卷积 kernel? 因为 im2col + GEMM 的性能通常比直接卷积 kernel 更好(GEMM 有成熟的优化,卷积没有)。catlass 的卷积模板就是 im2col + GEMM 的封装。

可以直接看 catlass 的源码吗

catlass 是 CANN 的开源组件,源码在 https://atomgit.com/cann/catlass 。

catlass 的代码比较偏底层(大量模板元编程),建议先从一个具体的例子看起(比如 catlass/examples/gemm_example.cpp),理解怎么用 catlass 写一个 GEMM kernel,然后再深入模板的实现。

catlass 仓库地址:https://atomgit.com/cann/catlass

cann-learning-hub(catlass 使用指南):https://atomgit.com/cann/cann-learning-hub

Logo

作为“人工智能6S店”的官方数字引擎,为AI开发者与企业提供一个覆盖软硬件全栈、一站式门户。

更多推荐