前言

矩阵乘法(GEMM)是深度学习里计算量最大的算子。ResNet 的 Conv2D、Transformer 的 Attention、LLaMA 的 FFN,底层全是 GEMM。GEMM 优化做好了,整个模型的训练和推理速度能翻倍。这篇文章深入 ops-blas 仓库,从分块策略、Cache 优化、到硬件单元调度,把 GEMM 性能调优讲透。

GEMM 的数学本质

标准定义

GEMM(General Matrix Multiply)的一般形式:

C = α · (A × B) + β · C

其中:
- A: (M, K) 矩阵
- B: (K, N) 矩阵
- C: (M, N) 输出矩阵
- α, β: 标量系数

深度学习中的 GEMM

神经网络里的 GEMM 有几个特点:

# 深度学习中常见的 GEMM 场景

# 场景1:全连接层(Linear)
# Y = X × W^T + B
# X: (batch, in_features) = (B, K)
# W: (out_features, in_features) = (N, K)
# Y: (batch, out_features) = (B, N)
# 即 GEMM: C = A × B, A=(B,K), B=(K,N), C=(B,N)

# 场景2:批量矩阵乘法(Batch GEMM)
# BMM: 多个矩阵对同时做乘法
# Y[i] = X[i] × W[i]  for i in batch
# 通常 batch_size = B, seq_len = M, hidden = K, out = N

# 场景3:分组卷积(Grouped Conv)
# 通道被分成 G 组,每组独立做 GEMM
# 适合 MobileNet、ShuffleNet 等轻量化网络

# 验证:手动实现一个 GEMM
import numpy as np

def naive_gemm(A, B, C=None, alpha=1.0, beta=0.0):
    """朴素 GEMM 实现,O(M·N·K)"""
    M, K = A.shape
    K2, N = B.shape
    assert K == K2

    C_out = np.zeros((M, N), dtype=np.float32) if C is None else C.copy() * beta

    for m in range(M):
        for n in range(N):
            for k in range(K):
                C_out[m, n] += alpha * A[m, k] * B[k, n]

    return C_out

# 测试
A = np.random.randn(1024, 512).astype(np.float32)
B = np.random.randn(512, 2048).astype(np.float32)
C = np.random.randn(1024, 2048).astype(np.float32)

result = naive_gemm(A, B, C, alpha=1.0, beta=1.0)
# 这个朴素实现极慢,深度学习框架用分块 + 向量化加速

昇腾 NPU 的计算单元

Cube 和 Vector 双单元

昇腾芯片有两种计算单元,分工不同:

单元 擅长 计算方式 适合操作
Cube 大矩阵乘法 3D Tensor Core(类似 NVIDIA Tensor Core) GEMM、Conv
Vector 逐元素操作 1D/2D 向量指令 ReLU、Softmax、Add
# ops-blas 中指定计算单元
import cann
from cann import ops_blas

# 方式1:自动选择(框架自动选择 Cube 或 Vector)
config = ops_blas.GemmConfig()
config.auto_backend = True

# 方式2:强制使用 Cube(适合大矩阵)
config.force_cube = True
config.cube_policy = "max_throughput"

# 方式3:强制使用 Vector(适合小矩阵)
config.force_vector = True
config.vector_policy = "low_latency"

# 创建 GEMM 算子
gemm = ops_blas.create_gemm(config)
result = gemm(A, B)

# 查询实际使用了哪个单元
print(f"实际使用计算单元: {gemm.backend_used}")
# 输出:Cube 或 Vector

Cube 计算单元的微架构

┌─────────────────────────────────────────────┐
│              Cube 计算单元                    │
│                                              │
│  输入 A: (M, K)    输入 B: (K, N)             │
│       ↓                    ↓                  │
│  ┌─────────────────────────────────────┐      │
│  │     Tensor Core (16x16x16)           │      │
│  │  一次完成 16×16×16 次乘加运算         │      │
│  │  输入 FP16,累加 FP32,输出 FP32      │      │
│  └─────────────────────────────────────┘      │
│                      ↓                         │
│                 输出 C: (M, N)                 │
└─────────────────────────────────────────────┘

# Tensor Core 的数据格式(Cube 特定):
# A 矩阵:按 (MTile, KTile) 分块存储
# B 矩阵:按 (KTile, NTile) 分块存储
# C 矩阵:按 (MTile, NTile) 分块存储

# 典型的 tile 大小:
# MTile = 64 或 128
# NTile = 64 或 128
# KTile = 16(固定,与 Tensor Core 宽度匹配)

分块策略:三维拆分 M/N/K

为什么要分块

不分块的 GEMM 有两个问题:

  1. Cache 不友好:矩阵太大,无法全部放入 Cache,数据被反复从显存读取
  2. 并行度不够:单个 GPU Core 无法高效调度大规模矩阵乘法

分块策略把 M、N、K 三个维度切分成小块(Tile),让每个 Tile 能放进 L2 Cache:

# ops-blas 分块配置
config = ops_blas.GemmConfig()

# M 方向分块(行方向)
config.m_block_size = 64      # 每个 block 处理 64 行

# N 方向分块(列方向)
config.n_block_size = 256     # 每个 block 处理 256 列

# K 方向分块(累加方向)
config.k_block_size = 64      # 每个 K 块处理 64 列

# 实际效果:矩阵被切成 64×256 的小块
# 每个小块可以完全放入 L2 Cache

分块后的计算流程

原始矩阵 A (M=1024, K=512) 和 B (K=512, N=4096)
设定:m_block=64, n_block=256, k_block=64

计算 C = A × B

Block 循环:
for m_block in range(0, M, 64):
    for n_block in range(0, N, 256):
        C[m_block:m_block+64, n_block:n_block+256] = 0
        for k_block in range(0, K, 64):
            # 每个 inner loop 处理一个 K 分块
            A_tile = A[m_block:m_block+64, k_block:k_block+64]
            B_tile = B[k_block:k_block+64, n_block:n_block+256]
            C[m_block:m_block+64, n_block:n_block+256] += A_tile × B_tile

# 关键点:K 方向的分块决定了 C 的累加次数
# K=512, k_block=64 -> 需要 8 次累加

实战:找到最优分块参数

# gemm_tune.py
import cann
from cann import ops_blas
import numpy as np
import time

def benchmark_gemm(M, N, K, m_block, n_block, k_block, n_iters=100):
    """测试特定分块参数的性能"""
    config = ops_blas.GemmConfig()
    config.m_block_size = m_block
    config.n_block_size = n_block
    config.k_block_size = k_block

    gemm = ops_blas.create_gemm(config)

    # 分配内存
    A = np.random.randn(M, K).astype(np.float16)
    B = np.random.randn(K, N).astype(np.float16)
    C = np.zeros((M, N), dtype=np.float32)

    # Warmup
    for _ in range(10):
        _ = gemm(A, B)

    # 正式测试
    times = []
    for _ in range(n_iters):
        start = time.time()
        _ = gemm(A, B)
        times.append((time.time() - start) * 1000)

    times.sort()
    return np.median(times)

# 测试不同的 M/N 分块组合
configs = [
    # (m_block, n_block, k_block)
    (16, 64, 16),
    (32, 128, 16),
    (64, 256, 16),
    (64, 512, 16),
    (128, 256, 16),
    (128, 512, 16),
    (256, 256, 16),
]

M, N, K = 1024, 4096, 512

print(f"GEMM 基准测试: M={M}, N={N}, K={K}")
print("-" * 60)
print(f"{'m_block':>10} {'n_block':>10} {'k_block':>10} {'延迟(ms)':>10} {'吞吐(GF/s)':>12}")
print("-" * 60)

for m, n, k in configs:
    try:
        t = benchmark_gemm(M, N, K, m, n, k)
        # 计算理论 GFLOPS
        flops = 2 * M * N * K / 1e9  # GFLOPS
        gf = flops / (t / 1000)
        print(f"{m:10d} {n:10d} {k:10d} {t:10.2f} {gf:12.1f}")
    except Exception as e:
        print(f"{m:10d} {n:10d} {k:10d}  {'ERROR':>10} {str(e):>12}")

# 输出示例:
# m_block   n_block   k_block      延迟(ms)     吞吐(GF/s)
# ------------------------------------------------------------
#        16        64        16         12.34         425.6
#        32       128        16          8.92         589.3
#        64       256        16          5.67         926.8
#        64       512        16          4.23        1241.5  <- 最优
#       128       256        16          4.89        1074.2
#       128       512        16          3.98        1319.8  <- 最优
#       256       256        16          5.12         1026.3

Cache 优化:L1/L2 策略

内存层级与 Tile 大小

┌─────────────────────────────────────────┐
│          HBM(显存)~32GB               │
│  访问延迟:~500-700 cycles               │
│  带宽:~256 GB/s                         │
└─────────────────────────────────────────┘
                    ↑
          L2 Cache(每 Core)~2MB
          访问延迟:~50-80 cycles
          ↑(Cache miss 时才访问 HBM)
┌─────────────────────────────────────────┐
│          L1 Cache(每 Core)~256KB       │
│  访问延迟:~5-10 cycles                  │
└─────────────────────────────────────────┘
                    ↑
           Cube 计算单元(Tensor Core)
           16×16×16 矩阵乘法单元

Tile 大小与 Cache 容量匹配

Tile 大小需要精心设计,既要充分利用 Cache 空间,又不能溢出:

# cache_tuning.py
import cann
from cann import ops_blas

# Cache 容量(昇腾 910 典型值)
L1_SIZE = 256 * 1024   # 256KB
L2_SIZE = 2 * 1024 * 1024  # 2MB

# L1 Cache 策略:存放 A 和 C 的 tile
# L1 可以容纳的 tile 数量(考虑 FP16 数据)
bytes_per_element = 2  # FP16
l1_capacity = L1_SIZE / bytes_per_element  # 131072 元素

# 每个 tile 需要的元素:(m_tile * k_tile) + (m_tile * n_tile)
# 对于 m_tile=64, n_tile=64, k_tile=64:
# L1 使用 = 64*64 + 64*64 = 8192 元素(仅占 L1 的 6.2%)
# 实际上还需要存 B 的 k_tile*n_tile = 4096 元素
# 总共约 12288 元素 < 131072,安全

# L2 Cache 策略:存放所有 tile 的累加结果
# L2 可以容纳:2MB / 2B = 1M 元素

def compute_optimal_tile_sizes():
    """根据 Cache 容量自动计算最优 tile 大小"""

    # 经验公式:L1 tile 约 16-32KB(数据 + 部分结果)
    # L2 tile 约 256KB - 512KB

    configs = {
        # (m_tile, n_tile, k_tile) -> L1 使用估算(KB), L2 使用估算(KB)
        (32, 128, 16): (32, 256),
        (64, 128, 16): (64, 512),
        (64, 256, 16): (128, 1024),
        (128, 128, 16): (128, 512),
        (128, 256, 16): (256, 1024),
    }

    print("Tile 大小 vs Cache 使用:")
    for tile, (l1_kb, l2_kb) in configs.items():
        status_l1 = "OK" if l1_kb < 256 else "溢出"
        status_l2 = "OK" if l2_kb < 2048 else "溢出"
        print(f"  {tile}: L1={l1_kb}KB [{status_l1}], L2={l2_kb}KB [{status_l2}]")

compute_optimal_tile_sizes()

预取策略

预取(Prefetch)让下一个 tile 的数据在当前 tile 计算时提前加载到 Cache:

# prefetch_demo.py
import cann
from cann import ops_blas

config = ops_blas.GemmConfig()

# 开启预取(默认开启)
config.enable_prefetch = True

# 预取距离(多少个 tile 提前加载)
config.prefetch_distance = 2  # 提前加载 2 个 tile

# 预取策略
config.prefetch_policy = "double_buffer"  # 双缓冲:计算和加载并行

# 测试预取效果
def test_prefetch():
    M, N, K = 2048, 4096, 1024

    # 无预取
    config_no_prefetch = ops_blas.GemmConfig()
    config_no_prefetch.enable_prefetch = False
    gemm_no_pf = ops_blas.create_gemm(config_no_prefetch)

    # 有预取
    config_pf = ops_blas.GemmConfig()
    config_pf.enable_prefetch = True
    config_pf.prefetch_distance = 2
    gemm_pf = ops_blas.create_gemm(config_pf)

    # 测试
    A = np.random.randn(M, K).astype(np.float16)
    B = np.random.randn(K, N).astype(np.float16)

    # Warmup
    for _ in range(10):
        _ = gemm_no_pf(A, B)
        _ = gemm_pf(A, B)

    import time
    t_no_pf = time.time()
    for _ in range(50):
        _ = gemm_no_pf(A, B)
    t_no_pf = (time.time() - t_no_pf) / 50 * 1000

    t_pf = time.time()
    for _ in range(50):
        _ = gemm_pf(A, B)
    t_pf = (time.time() - t_pf) / 50 * 1000

    print(f"无预取延迟: {t_no_pf:.2f}ms")
    print(f"有预取延迟: {t_pf:.2f}ms")
    print(f"提升: {(t_no_pf - t_pf) / t_no_pf * 100:.1f}%")

    # 输出示例:
    # 无预取延迟: 15.67ms
    # 有预取延迟: 12.34ms
    # 提升: 21.3%

向量化与数据排布

数据排布格式(Layout)

昇腾 NPU 主要使用两种数据排布:

# layout_demo.py

# 格式1:NCHW(RowMajor / 行优先)
# 最直接的内存布局,但 Cube 计算效率低
# 适用于:中间结果传递、Debug
A_nchw = np.ones((1024, 512), dtype=np.float16)  # 行优先存储

# 格式2:FRN(Fixed Rotation N)/ NC1HWC0(昇腾专用格式)
# Cube 计算单元的原生格式,高效利用带宽
# NC1HWC0 解释:
#   N: batch
#   C1: ceil(C/16),C 向上取整到 16 的倍数
#   H: height
#   W: width
#   C0: 16,C 的最小计算单位

def to_nc1hwc0(tensor, c0=16):
    """把 NCHW 转成昇腾的 NC1HWC0 格式"""
    N, C, H, W = tensor.shape
    C1 = (C + c0 - 1) // c0
    output = np.zeros((N, C1, H, W, c0), dtype=tensor.dtype)
    for c in range(C):
        c1 = c // c0
        c0_idx = c % c0
        output[:, c1, :, :, c0_idx] = tensor[:, c, :, :]
    return output

# 示例:Conv 输入转换
conv_input = np.random.randn(1, 64, 224, 224).astype(np.float16)
conv_input_nc1hwc0 = to_nc1hwc0(conv_input)
print(f"原始形状: {conv_input.shape}")
print(f"NC1HWC0 形状: {conv_input_nc1hwc0.shape}")
# 输出:
# 原始形状: (1, 64, 224, 224)
# NC1HWC0 形状: (1, 4, 224, 224, 16)

昇腾 GEMM 的原生数据格式

ops-blas 对 GEMM 做了专门的格式适配:

# ops_blas_gemm.py
import cann
from cann import ops_blas

# A 矩阵:(M, K)
# B 矩阵:(K, N)
# C 矩阵:(M, N)

# 昇腾 GEMM 原生接口
class GemmOp:
    def __init__(self, trans_a=False, trans_b=False, trans_c=False):
        """
        trans_a: 是否转置 A
        trans_b: 是否转置 B
        trans_c: 是否转置 C
        """
        self.trans_a = trans_a
        self.trans_b = trans_b
        self.trans_c = trans_c

    def __call__(self, A, B, C=None, alpha=1.0, beta=0.0):
        """
        计算: C = alpha * (A @ B) + beta * C
        A, B, C 必须是 npu tensor
        """
        pass

# 示例:计算 Y = X × W^T + B(Linear 层)
def linear_layer_npu(x, weight, bias=None):
    """
    x: (batch, in_features) = (B, K)
    weight: (out_features, in_features) = (N, K)
    返回: (batch, out_features) = (B, N)
    """
    gemm = GemmOp(trans_b=True)  # B 转置 -> W^T

    if bias is not None:
        # C = X @ W^T + bias
        return gemm(x, weight, bias, alpha=1.0, beta=1.0)
    else:
        # C = X @ W^T
        return gemm(x, weight, alpha=1.0, beta=0.0)

# 示例:计算 Attention 的 QK^T
def attention_score_npu(Q, K):
    """
    Q: (B, N_heads, T, D)  query
    K: (B, N_heads, T, D)  key
    返回: (B, N_heads, T, T)  attention score
    """
    B, H, T, D = Q.shape
    Q_2d = Q.reshape(B * H, T, D)  # (B*H, T, D)
    K_2d = K.reshape(B * H, T, D)  # (B*H, T, D)

    gemm = GemmOp(trans_b=True)
    scores = gemm(Q_2d, K_2d)  # (B*H, T, T)

    return scores.reshape(B, H, T, T)

混合精度与量化

FP16/BF16/FP32 混用策略

# mixed_precision.py
import cann
from cann import ops_blas

config = ops_blas.GemmConfig()

# 输入:A/B 用 FP16(计算快)
config.dtype_a = "float16"
config.dtype_b = "float16"

# 累加:中间结果用 FP32(防止溢出)
config.accum_dtype = "float32"

# 输出:C 可以是 FP16 或 FP32
config.dtype_c = "float32"  # 推荐用 FP32 输出,再转 FP16

# 计算:Cube Tensor Core 在 FP16 输入时效率最高
# Tensor Core 自动做 FP16*FP16 -> FP32 累加 -> FP16/FP32 输出

# BF16 支持(昇腾 910B+)
config.dtype_a = "bfloat16"
config.dtype_b = "bfloat16"
config.accum_dtype = "float32"

# BF16 vs FP16 对比:
# FP16: 1 sign + 5 exp + 10 mantissa = 16 bits, range: 6e-5 ~ 65504
# BF16: 1 sign + 8 exp + 7 mantissa = 16 bits, range: 9e-5 ~ 3e38
# BF16 动态范围更大,适合训练;FP16 精度更高,适合推理

INT8 量化 GEMM

# int8_gemm.py
import cann
from cann import ops_blas

# 量化配置
quant_config = ops_blas.QuantizedGemmConfig()

# 动态量化(每行/每列独立 scale)
quant_config.quant_mode = "dynamic"
quant_config.activation_dtype = "int8"
quant_config.weight_dtype = "int8"

# 量化因子
quant_config.scale_a = np.ones((1024, 1), dtype=np.float32) / 127.0
quant_config.scale_b = np.ones((1, 4096), dtype=np.float32) / 127.0

# 创建量化 GEMM
qgemm = ops_blas.create_quantized_gemm(quant_config)

# 输入
A_int8 = (np.random.randn(1024, 512) * 80).astype(np.int8)
B_int8 = (np.random.randn(512, 4096) * 80).astype(np.int8)

# 计算
C_int32 = qgemm(A_int8, B_int8)

# 反量化到 FP32
scale = quant_config.scale_a @ quant_config.scale_b
C_fp32 = C_int32.astype(np.float32) * scale * 127.0

print(f"量化 GEMM: A={A_int8.shape} (int8), B={B_int8.shape} (int8)")
print(f"输出: C={C_fp32.shape} (fp32)")
print(f"INT8 带宽节省: {(1 - 2/8) * 100:.0f}%(vs FP32)")

# 输出:
# 量化 GEMM: A=(1024, 512) (int8), B=(512, 4096) (int8)
# 输出: C=(1024, 4096) (fp32)
# INT8 带宽节省: 75%(vs FP32)

性能实测:端到端优化

BERT GEMM 性能分析

# bert_gemm_profile.py
import cann
from cann import ops_blas
import numpy as np
import time

def profile_bert_gemms():
    """Profile BERT 中各类 GEMM 的性能"""

    configs = {
        # BERT 中的 GEMM 形状
        "QKV projection": {"M": 384, "N": 3072, "K": 768},
        "Attention scores": {"M": 384, "N": 768, "K": 64},  # per head
        "FFN first layer": {"M": 384, "N": 3072, "K": 768},
        "FFN second layer": {"M": 384, "N": 768, "K": 3072},
    }

    print("BERT GEMM 性能 Profile:")
    print("-" * 70)

    for name, shape in configs.items():
        M, N, K = shape["M"], shape["N"], shape["K"]

        # 准备数据
        A = np.random.randn(M, K).astype(np.float16)
        B = np.random.randn(K, N).astype(np.float16)

        # 基准配置
        config = ops_blas.GemmConfig()
        gemm = ops_blas.create_gemm(config)

        # Warmup
        for _ in range(20):
            _ = gemm(A, B)

        # 测试 100 次
        times = []
        for _ in range(100):
            start = time.time()
            _ = gemm(A, B)
            times.append((time.time() - start) * 1000)

        median_t = np.median(times)

        # 计算吞吐
        flops = 2 * M * N * K / 1e9
        throughput = flops / (median_t / 1000)

        # 计算利用率(假设昇腾 910 峰值 256 GFLOPS FP16)
        peak_fp16 = 256
        utilization = throughput / peak_fp16 * 100

        print(f"{name:25s} | Shape: ({M}, {N}, {K}) | "
              f"Latency: {median_t:6.2f}ms | "
              f"GFLOPS: {throughput:6.1f} | "
              f"Util: {utilization:5.1f}%")

    print("-" * 70)

# profile_bert_gemms()

# 输出示例:
# BERT GEMM 性能 Profile:
# ----------------------------------------------------------------------
# QKV projection           | Shape: (384, 3072, 768) | Latency:  2.34ms | GFLOPS: 1215.2 | Util: 87.3%
# Attention scores         | Shape: (384, 768, 64)    | Latency:  0.45ms | GFLOPS:  423.1 | Util: 30.5%
# FFN first layer          | Shape: (384, 3072, 768)  | Latency:  2.31ms | GFLOPS: 1218.5 | Util: 87.5%
# FFN second layer         | Shape: (384, 768, 3072)  | Latency:  4.12ms | GFLOPS: 1220.3 | Util: 87.6%
# ----------------------------------------------------------------------
# 分析:Attention scores 的利用率只有 30%,因为矩阵太小,N=768 不适合 Cube 单元

Attention scores 的特殊优化

# attention_small_gemm.py

# 小矩阵 GEMM(Attention scores: M=384, N=768, K=64)
# 直接用 FP16 Cube 不够高效,有几种优化方案:

# 方案1:多 batch 拼接
# 把多个 head 的小矩阵拼成一个大矩阵
def batch_attention_scores(Q_all, K_all, num_heads=12):
    """
    Q_all: (B, T, H*D) -> (B*H, T, D)
    K_all: (B, T, H*D) -> (B*H, T, D)
    """
    B, T, HD = Q_all.shape
    D = HD // num_heads

    # reshape: (B, H, T, D)
    Q = Q_all.view(B, num_heads, T, D)
    K = K_all.view(B, num_heads, T, D)

    # transpose: (B, H, D, T)
    K_t = K.transpose(2, 3)

    # 批量 GEMM:Q @ K^T
    # Q: (B, H, T, D), K_t: (B, H, D, T)
    # 结果: (B, H, T, T)
    config = ops_blas.GemmConfig()
    config.force_cube = True
    config.m_block_size = 384  # T
    config.n_block_size = 384  # T
    config.k_block_size = 64   # D

    gemm = ops_blas.create_gemm(config)

    scores = gemm(Q.view(-1, T, D), K_t.view(-1, D, T))
    return scores.view(B, num_heads, T, T)

# 方案2:切换到 Vector 单元(极小矩阵)
def vector_attention_scores(Q, K):
    """
    当矩阵极小时,用 Vector 单元反而更快
    Vector 单元无 tile 开销
    """
    config = ops_blas.GemmConfig()
    config.force_vector = True  # 强制 Vector
    gemm = ops_blas.create_gemm(config)

    # K 转置 + GEMM
    K_t = K.transpose(0, 1)  # (D, T)
    scores = gemm(Q, K_t)   # (B*H*T, T)

    return scores.view(-1, T, T)

# 方案3:自动选择(推荐)
def auto_attention_scores(Q, K):
    """
    ops-blas 自动选择最优 backend
    """
    config = ops_blas.GemmConfig()
    config.auto_backend = True  # 自动选择 Cube 或 Vector
    gemm = ops_blas.create_gemm(config)

    scores = gemm(Q, K.transpose(0, 1))
    return scores.view(-1, T, T)

分布式 GEMM:多卡并行

数据并行 GEMM

# distributed_gemm.py
import cann
from cann import ops_blas
import torch.distributed as dist

def distributed_gemm_row(A, B, rank, world_size):
    """
    行并行:按 M 维度切分 A
    每个 rank 持有 A 的 (M/world_size, K) 部分
    """
    M, K = A.shape
    M_local = M // world_size

    # 每个 rank 只拿自己那部分 A
    A_local = A[rank * M_local:(rank + 1) * M_local, :]

    # 全量 B(每个 rank 都有 B 的副本)
    B_full = B

    # 本地 GEMM
    C_local = ops_blas.gemm(A_local, B_full)

    # AllGather:收集所有 rank 的结果
    C_full = [None for _ in range(world_size)]
    dist.all_gather_object(C_full, C_local)

    return np.concatenate(C_full, axis=0)

def distributed_gemm_col(A, B, rank, world_size):
    """
    列并行:按 N 维度切分 B
    每个 rank 持有 B 的 (K, N/world_size) 部分
    需要 AllReduce 汇总
    """
    M, K = A.shape
    K2, N = B.shape
    N_local = N // world_size

    # 全量 A(每个 rank 都有 A 的副本)
    A_full = A

    # 每个 rank 只拿自己那部分 B
    B_local = B[:, rank * N_local:(rank + 1) * N_local]

    # 本地 GEMM
    C_local = ops_blas.gemm(A_full, B_local)

    # AllReduce:所有 rank 的 C_local 加起来
    # C = A @ B = A @ (B0 + B1 + ... + Bn)
    #           = A @ B0 + A @ B1 + ... + A @ Bn
    C = np.zeros_like(C_local)
    dist.all_reduce(C_local)
    C = C_local

    return C

总结:ops-blas GEMM 调优检查清单

优化项 默认值 建议调优 效果
M/N 分块大小 64/256 128/512(适配 L2 Cache) +20-30%
K 分块大小 16 固定(Tensor Core 要求) 基准
计算单元 自动 小矩阵用 Vector,大矩阵用 Cube +10-50%
数据格式 NCHW NC1HWC0(Cube 原生) +15-25%
预取 开启 保持开启 +15-20%
混合精度 FP32 FP16 输入 + FP32 累加 +2x
量化 INT8 动态量化 +2-4x

调优步骤:

  1. 先 profiling:找到瓶颈 GEMM 是哪个
  2. 优化 tile 大小:适配 L2 Cache 容量
  3. 选择计算单元:小矩阵用 Vector,大矩阵用 Cube
  4. 开启混合精度:FP16 是昇腾的甜点精度
  5. 验证正确性:numerical error < 0.1%

GEMM 没有银弹,但 profiling + 调参是最有效的方法。

仓库地址:https://atomgit.com/cann/ops-blas

Logo

作为“人工智能6S店”的官方数字引擎,为AI开发者与企业提供一个覆盖软硬件全栈、一站式门户。

更多推荐