CUTLASS 是 NVIDIA 的矩阵乘模板库,catlass 是昇腾的对应物——用 C++ 模板元编程在编译期生成算子,运行时零开销。核心思路:把算子拆成可组合的模板参数,编译期决定一切(tile 大小、数据布局、指令选择),运行期只做数据搬运和计算。

为什么不用普通函数?GEMM 有 15 个可调参数(tile M/N/K、向量化宽度、是否用 TF32、是否 fuse ReLU)——15 维参数空间,手写 15 个特化版本不现实。模板元编程让编译器帮你生成。

核心抽象:Tile Iterator

catlass 的一切围绕 TileIterator——一个编译期知道 tile 大小、数据布局、向量化宽度的迭代器。它不是 runtime 的 for 循环,是编译期的类型计算。

// catlass/include/cutlass/tile_itererator.h

// 模板参数:
//   - Shape_MNK:tile 的 M/N/K 大小(编译期常量)
//   - Element:数据类型(float16/float32/bfloat16)
//   - Layout:数据布局(RowMajor/ColumnMajor/VoltaLayout)
//   - VectorWidth:向量化加载宽度(4/8/16 个元素一次加载)

template <
    typename Shape_MNK,      // cutlass::MatrixShape<128, 128, 32>
    typename Element,         // float16
    typename Layout,          // cutlass::layout::RowMajor
    int VectorWidth = 8      // 一次加载 8 个 float16 = 128 bits
>
class TileIterator {
public:
    using Shape = Shape_MNK;
    using AccessType = Array<Element, VectorWidth>;  // 向量化访问类型

    // 编译期计算:一个 tile 需要几次向量化加载
    static constexpr int kAccessCount =
        Shape::kMN / VectorWidth * Shape::kN / VectorWidth *
        Shape::kK / VectorWidth;

    // 构造函数:绑定到全局内存的基地址
    CUTLASS_DEVICE
    TileIterator(Element* ptr, int stride)
        : ptr_(reinterpret_cast<AccessType*>(ptr)),
          stride_(stride / VectorWidth) {}

    // 加载一个 tile 到寄存器(向量化)
    CUTLASS_DEVICE
    void Load(Tile<Shape, Element>& tile, int m, int n, int k) {
        // 编译期展开:kAccessCount 次向量化加载
        CUTLASS_PRAGMA(unroll)
        for (int i = 0; i < kAccessCount; ++i) {
            // 地址计算:编译期决定,无运行时开销
            AccessType* addr = ptr_ + (m * stride_ + n) + i;
            tile.data[i] = *addr;  // 向量化加载(128 bits 一次)
        }
    }

    // 写回一个 tile
    CUTLASS_DEVICE
    void Store(const Tile<Shape, Element>& tile, int m, int n, int k) {
        CUTLASS_PRAGMA(unroll)
        for (int i = 0; i < kAccessCount; ++i) {
            AccessType* addr = ptr_ + (m * stride_ + n) + i;
            *addr = tile.data[i];
        }
    }

private:
    AccessType* ptr_;  // 向量化指针(不是 Element*)
    int stride_;         // 以 AccessType 为单位的 stride
};

CUTLASS_PRAGMA(unroll) 让编译器把循环完全展开——最终生成的代码没有循环,只有 16 条 load128(向量化加载指令)。这是"零成本抽象"的含义:模板代码写起来像泛型,编译后和手写汇编一样高效。

算子融合的模板实现:GEMM + Bias + ReLU

独立的 GEMM 算子:C = A × B。融合版本:C = ReLU(A × B + Bias)。不用融合的话需要三个 kernel:GEMM → AddBias → ReLU,两次 HBM 往返。

catlass 的融合在模板层面完成——不是 runtime 的 if,是编译期生成专用的融合 kernel。

// catlass/include/cutlass/gemm/kernel/gemm_with_fusion.h

// 融合策略:模板参数决定融合什么
//   - Gemm:基础矩阵乘
//   - Epilogue:尾部操作(bias add / activation / elementwise)

template <
    typename GemmShape,         // <128, 128, 32> — tile 大小
    typename EpilogueOp,        // cutlass::epilogue::thread::ReLU
    typename ElementA,          // float16
    typename ElementB,          // float16
    typename ElementC,          // float32(输出类型)
    typename ElementBias        // float32(bias 类型)
>
class GemmWithFusion {
public:
    using Epilogue = EpilogueOp;

    // kernel 主函数
    CUTLASS_DEVICE
    void operator()(
        ElementC* ptr_C, int stride_C,
        ElementA* ptr_A, int stride_A,
        ElementB* ptr_B, int stride_B,
        ElementBias* ptr_bias  // bias 指针(融合用)
    ) {
        // === 阶段 1:分块加载 A 和 B ===
        TileIteratorA iteator_A(ptr_A, stride_A);
        TileIteratorB iteator_B(ptr_B, stride_B);

        // Tile 在寄存器/SMEM 中
        Tile<GemmShape, ElementA> tile_A;
        Tile<GemmShape, ElementB> tile_B;

        iterator_A.Load(tile_A, blockIdx.x * 128, 0, threadIdx.x);
        iterator_B.Load(tile_B, 0, blockIdx.y * 128, threadIdx.x);

        // === 阶段 2:矩阵乘(MMA 指令)===
        // Ascend NPU 用 Cube 单元做矩阵乘
        // catlass 把 MMA 封装成模板——编译期选择指令
        using MmaOp = typename MmaPromotion<GemmShape, ElementA, ElementB, ElementC>::Op;
        MmaOp mma_op;

        Tile<GemmShape, ElementC> accum;  // 累加器
        mma_op(accum, tile_A, tile_B, accum);  // C = A × B + C

        // === 阶段 3:Epilogue(融合的尾部操作)===
        // 这是融合的核心——epilogue 在矩阵乘完成后立即执行
        // 不需要写回 HBM 再读出来

        Epilogue epilogue;

        // Step 3.1:加载 bias(如果融合了这个操作)
        if constexpr (Epilogue::kHasBias) {
            Tile<GemmShape, ElementBias> tile_bias;
            // bias 是 [N] 向量,广播到 [M, N] tile
            LoadBias<branch::RowBroadcast>(tile_bias, ptr_bias, blockIdx.y * 128);
            epilogue.AddBias(accum, tile_bias);
        }

        // Step 3.2:激活函数(ReLU / GELU / SiLU)
        if constexpr (Epilogue::kHasActivation) {
            epilogue.ApplyActivation(accum);  // ReLU: max(0, x)
        }

        // Step 3.3:写回(这是唯一一次 HBM 写)
        TileIteratorC iteator_C(ptr_C, stride_C);
        iterator_C.Store(accum, blockIdx.x * 128, blockIdx.y * 128, 0);
    }
};

融合的关键:if constexpr 是 C++17 的特性——编译期 if。如果 Epilogue::kHasBias == false,整段 bias 加载代码在编译期被删除,最终二进制里不存在。这是"零成本"的另一层含义:没用到的融合组件不占代码空间。

编译期计算:Tensor Core 指令选择

Ascend NPU 的 Cube 单元支持多种矩阵乘指令:MMA_F16(float16 输入,float32 累加)、MMA_BF16(bfloat16)、MMA_TF32(TensorFloat32,仅 Ampere+)。catlass 用模板特化在编译期选择。

// catlass/include/cutlass/arch/mma.h

// 基础 MMA 操作描述(编译期常量)
template <typename Shape, typename ElementA, typename ElementB, typename ElementC>
struct MmaPromotion;

// === 特化 1:float16 × float16 → float32(最常用)===
template <int M, int N, int K>
struct MmaPromotion<MatrixShape<M, N, K>, float16, float16, float32> {
    using Op = MmaSychronized<
        MatrixShape<16, 8, 16>,  // Tensor Core 的 warp-level tile 大小
        float16,                  // A 类型
        float16,                  // B 类型
        float32,                  // C 类型(累加器)
        layout::RowMajor         // C 的布局
    >;

    // 一条 MMA 指令处理 16×8×16 的 tile
    // 256 个 thread(一个 warp)= 16 warps × 16 threads
    // 每个 thread 负责 1 个输出元素 → 16 warps × 16 threads = 256 elements
};

// === 特化 2:bfloat16 × bfloat16 → float32 ===
template <int M, int N, int K>
struct MmaPromotion<MatrixShape<M, N, K>, bfloat16, bfloat16, float32> {
    using Op = MmaSychronized<
        MatrixShape<16, 8, 16>,
        bfloat16,
        bfloat16,
        float32,
        layout::RowMajor
    >;
};

// === 特化 3:TensorFloat32(TF32,Ampere+ 专用)===
template <int M, int N, int K>
struct MmaPromotion<MatrixShape<M, N, K>, tfloat32, tfloat32, float32> {
    using Op = MmaSychronized<
        MatrixShape<16, 8, 8>,   // TF32 的 K 维度只有 8(精度降低)
        tfloat32,
        tfloat32,
        float32,
        layout::RowMajor
    >;
};

// === 用法:自动选择 ===
template <typename Shape, typename ElementA, typename ElementB, typename ElementC>
CUTLASS_DEVICE
void GemmKernel(ElementC* C, ElementA* A, ElementB* B, int M, int N, int K) {
    using Mma = typename MmaPromotion<Shape, ElementA, ElementB, ElementC>::Op;
    Mma mma;

    // 编译期根据 ElementA/B/C 自动选择 MMA 指令
    // 不需要 runtime if—编译器帮你选
    Tile<Shape, ElementC> accum;
    mma(accum, A, B, accum);
    // ...
}

Ascend NPU 没有 TF32(这是 NVIDIA 的专用格式),但 catlass 保持了和 CUTLASS 相同的接口——方便从 NVIDIA 迁移代码。

分层模板架构

catlass 的模板不是一团乱——分了 4 层,每层负责不同的抽象:

第 1 层:Kernel 层(最顶层)
  ↓ 决定:GEMM / Conv / Transform
  ↓ 决定:融合策略(epilogue 操作)
  GemmWithFusion<Shape, Epilogue>

第 2 层:Tile 层
  ↓ 决定:tile 大小(128×128×32 / 256×128×64)
  ↓ 决定:数据布局(RowMajor / ColumnMajor)
  TileIteratorA / TileIteratorB

第 3 层:Warp 层
  ↓ 决定:Warp 内部的 MMA 指令映射
  ↓ 决定:寄存器分配
  MmaSychronized<WarpShape, ElementA, ...>

第 4 层:指令层(最底层)
  ↓ 直接映射到 NPU 指令
  ↓ Cube 单元:MMA / MMA_SYNC
  ↓ Vector 单元:FMA / FFMA
  cute::asmsm("mma.sync...", ...)

每层只和上下层交互——改 tile 大小只需改第 2 层,不影响第 3/4 层。

实战:自定义融合算子(GEMM + SiLU)

SiLU(Sigmoid Linear Unit)= x × sigmoid(x),LLaMA 的激活函数。融合到 GEMM 的 epilogue:

// catlass/examples/gemm_silu/gemm_silu.cu

// 步骤 1:定义 Epilogue 操作(SiLU)
namespace cutlass {
namespace epilogue {
namespace thread {

class SiLU {
public:
    // Epilogue 操作必须实现 `operator()`(对 tile 的每个元素应用)
    template <typename Element>
    CUTLASS_DEVICE
    Element operator()(Element x) const {
        // SiLU(x) = x × sigmoid(x)
        // sigmoid(x) = 1 / (1 + exp(-x))
        float x_f = float(x);
        float sigmoid = 1.0f / (1.0f + expf(-x_f));
        return Element(x_f * sigmoid);
    }

    // 融合标识:编译期常量
    static constexpr bool kHasBias = false;
    static constexpr bool kHasActivation = true;
    static constexpr bool kHasSilU = true;  // 新增:SiLU 标识
};

}  // namespace thread
}  // namespace epilogue
}  // namespace cutlass

// 步骤 2:用自定义 Epilogue 实例化 GEMM
using GemmShape = cutlass::MatrixShape<128, 128, 32>;

using GemmSiLU = cutlass::gemm::kernel::GemmWithFusion<
    GemmShape,
    cutlass::epilogue::thread::SiLU,  // 自定义 epilogue
    float16,                          // ElementA
    float16,                          // ElementB
    float32,                          // ElementC
    float32                           // ElementBias(未用)
>;

// 步骤 3:启动 kernel
void LaunchGemmSiLU(
    float32* C, float16* A, float16* B,
    int M, int N, int K
) {
    dim3 grid((M + 128 - 1) / 128, (N + 128 - 1) / 128);
    dim3 block(256);  // 一个 warpgroup = 256 threads

    GemmSiLU kernel;
    kernel<<<grid, block>>>(C, M, A, K, B, N, nullptr);
}

编译期展开后,最终生成的 PTX(并行线程执行汇编)大致是:

; 伪代码:展开后的融合 kernel
ld.shared.f16 %fA, [tile_A];    ; 从 shared memory 加载 A
ld.shared.f16 %fB, [tile_B];    ; 加载 B
mma.sync.aligned.m16n8k16.f16.f16.f32 {...};  ; 矩阵乘
add.f32 %fC, %fC, %fAB;       ; 累加到 C

; Epilogue:SiLU(融合在这,不写回 HBM)
ex2.approx.ftz.f32 %fneg, -%fC;  ; exp(-x) 近似
add.f32 %fdenom, 1.0, %fneg;  ; 1 + exp(-x)
div.approx.f32 %fSigmoid, 1.0, %fdenom;  ; sigmoid
mul.f32 %fSiLU, %fC, %fSigmoid;  ; x × sigmoid

st.global.f32 [%rdC], %fSiLU;    ; 唯一一次 HBM 写

对比非融合版本(3 个 kernel):

; 非融合版本
; Kernel 1: GEMM
mma.sync...; add.f32 %fC, ...;
st.global.f32 [%rdC], %fC;      ; ← 写 HBM(第一次)

; Kernel 2: AddBias(重新加载)
ld.global.f32 %fC, [%rdC];      ; ← 读 HBM(第一次)
add.f32 %fC_bias, %fC, %fBias;
st.global.f32 [%rdC], %fC_bias; ; ← 写 HBM(第二次)

; Kernel 3: SiLU(重新加载)
ld.global.f32 %fC, [%rdC];      ; ← 读 HBM(第二次)
ex2.approx...; div...; mul...;
st.global.f32 [%rdC], %fSiLU;    ; ← 写 HBM(第三次)

三次 HBM 往返 vs 一次——延迟差 3×(假设 HBM 带宽 900GB/s)。

踩坑一:模板递归深度超过编译限制

catlass 大量使用模板递归(不是 runtime 递归)展开循环。编译期递归深度默认限制是 256——但某些大 tile(512×512×64)的递归深度会超过。

// ❌ 模板递归深度 1024 → 编译错误:recursive template instantiation depth exceeded
template <int N>
struct Factorial {
    static constexpr int value = N * Factorial<N - 1>::value;
};
template <>
struct Factorial<0> { static constexpr int value = 1; };

int x = Factorial<1024>::value;  // 编译错误!

// ✅ 用 constexpr 函数替代(C++14+,无递归深度限制)
constexpr int Factorial(int N) {
    int result = 1;
    for (int i = 1; i <= N; ++i) result *= i;
    return result;
}

constexpr int x = Factorial(1024);  // OK(编译期计算)

catlass 自己的代码用了大量模板递归——如果遇到编译错误,在 g++ 下加 -ftemplate-depth=1024,在 clang 下加 -ftemplate-depth=1024

踩坑二:Auto-tuning 搜索空间爆炸

catlass 的 tile 大小、VectorWidth、epilogue 融合组合是 15 维参数空间。暴力搜索(每个组合跑一遍 benchmark)需要 2^15 = 32768 次编译+运行——不现实。

# ❌ 暴力搜索:32768 种组合,每种编译 30 秒 → 273 小时
for tile_m in [64, 128, 256]:
    for tile_n in [64, 128, 256]:
        for tile_k in [16, 32, 64]:
            for vec_width in [4, 8, 16]:
                for epilogue in [None, ReLU, GELU, SiLU]:
                    CompileAndBenchmark(...)

# ✅ 分层搜索:先定 tile 大小,再调 epilogue
# 第 1 步:固定 epilogue=None,搜索最优 tile
best_tile = SearchTileSize(epilogue=None)  # 只搜索 3×3×3 = 27 种

# 第 2 步:用 best_tile,搜索最优 epilogue 融合
best_epilogue = SearchEpilogue(tile=best_tile)  # 只搜索 4 种

# 总时间:27 + 4 = 31 次编译 → 15 分钟

分层搜索的理论依据:tile 大小对性能的影响是数量级(共享内存占用、寄存器压力),epilogue 融合的影响是百分比(省 1-2 次 HBM 往返)。先调数量级,再调百分比。

踩坑三:模板报错信息完全不可读

模板元编程的报错信息(尤其是类型不匹配)可以达到 500 行——因为编译器要展开所有模板实例化路径才报错。

error: no matching function for call to 'Load'
note: candidate: template<class T> void Load(Tile<Shape, Element>&, int, int, int) [with T = float16; Shape = MatrixShape<128, 128, 32>; Element = float32]
note:   mismatched types 'Element' (float32) vs argument type 'float16 [128]'

这个报错的意思是:Load(tile_float32, ..., ptr_float16)——tile 是 float32,但指针是 float16。

// ❌ 报错 500 行
Tile<Shape128x128x32, float32> tile;  // accum 是 float32
float16* ptr = ...;
iterator.Load(tile, 0, 0, 0);  // ptr 是 float16 → 类型不匹配

// ✅ 用静态断言(static_assert)在编译期给出可读错误
template <typename Element, typename AccessType>
CUTLASS_DEVICE
void Load(Tile<Shape, Element>& tile, AccessType* ptr, ...) {
    static_assert(std::is_same<Element, AccessType>::value,
                  "Load: Element type must match AccessType");
    // ...
}

static_assert 会在类型不匹配时输出:error: static assertion failed: Load: Element type must match AccessType——一行,不是 500 行。

catlass 的代码大量使用 static_assert 和类型别名(using)来让报错可读。遇到模板报错,先搜 static_assert 再看类型推导。


catlass 的本质:用 C++ 模板元编程在编译期生成算子,运行时零开销。核心抽象 TileIterator 向量化加载 + 循环展开 → 最终代码和手写汇编一样高效。融合算子(GEMM + epilogue)靠 if constexpr 编译期消除未用代码——二进制里没有 dead code。踩坑集中在三方面:模板递归深度超限(加编译选项)、搜索空间爆炸(分层搜索)、报错不可读(用 static_assert)。

Logo

作为“人工智能6S店”的官方数字引擎,为AI开发者与企业提供一个覆盖软硬件全栈、一站式门户。

更多推荐