做昇腾NPU上大模型推理优化的时候,我发现一个很反直觉的现象:同样都是矩阵乘法,不同的内存排布方式,性能可以差出一倍还多。这件事把我带进了ops-nn仓库的源码里,花了一周把MatMul算子的内存排布逻辑翻了个底朝天。下面把搞清楚的东西整理出来。

问题从哪来

LLaMA-7B做推理,Linear层全是MatMul。我用pyasc跑的时候,同一个模型,换一个推理框架(比如从原生PyTorch换成vLLM的昇腾适配版),端到端吞吐能从1800 tokens/s跳到3200 tokens/s。差的就是MatMul这块。

拆开来看,MatMul在昇腾CANN里的调用链路是:应用代码(PyTorch/TensorFlow)→ Framework Adapter → AscendCL → ops-nn仓库里的MatMul算子 → Ascend C kernel → 达芬奇架构AI Core。

ops-nn是CANN开源社区里的神经网络类基础算子库,和ops-math、ops-blas、ops-transformer这些仓库并列,都位于CANN五层架构的第2层(昇腾计算服务层)。它底层依赖opbase基础组件,上层被ATB加速库和各类训练/推理框架调用。

内存排布的核心矛盾

CPU和GPU上的矩阵乘法,大家习惯用行主序(Row-major):A[M, K] × B[K, N],A在内存里是连续的,B按列访问的时候要跳 stride。

昇腾的达芬奇架构不一样。AI Core的矩阵计算单元(Cube Unit)吃数据的格式是 NZ格式(也叫做 Fractal NZ),这是一种分块存储格式:

  • 把矩阵按 16×16 的块(tile)切分
  • 每个 tile 内部,先存 16 个元素(一个 16×16 的 tile 共 256 个元素)
  • 所有 tile 按特定顺序排列

为什么这么搞?因为 Cube Unit 一次就吃 16×16 的矩阵块,用 NZ 格式可以让数据搬运和计算的流水线完全接上,不需要在 kernel 里做额外的重排。

矛盾就在这:PyTorch 的 tensor 是行主序的,要喂给 MatMul kernel 得先转成 NZ 格式。这个转换要不要做、在哪里做、做一次还是每次都做,直接决定了性能。

ops-nn 里的三种路径

翻 ops-nn 的 MatMul 实现,发现它针对这个问题给了三条路:

路径1:每次都转(Preprocess 模式)

输入是行主序,kernel 内部先调一个 TransData 算子把 A 和 B 转成 NZ,再送给 Cube Unit 算。适合输入 shape 每轮都在变的训练场景。

# 训练场景:每次 shape 都变,必须现场转
import torch
import torch_npu

# A: [128, 4096] row-major, B: [4096, 11008] row-major
a = torch.randn(128, 4096, dtype=torch.float16, device="npu:0")
b = torch.randn(4096, 11008, dtype=torch.float16, device="npu:0")

# torch.matmul 内部会走 ops-nn 的 MatMul
# 每次都要做 TransData,有固定开销
out = torch.matmul(a, b)
print(out.shape)  # [128, 11008]

路径2:预转存(Pre-transformed 模式)

推理场景下,权重 B 是固定的。可以在模型加载的时候把 B 转成 NZ 格式存下来,每次推理只用转 A(batch 维度通常很小,128 或 1,转的代价可以忽略)。

# 推理场景:权重预转换
import torch
import torch_npu

# 加载权重(行主序)
weight = torch.load("fc2.weight.pt", dtype=torch.float16)  # [11008, 4096]

# 一次性转成 NZ 格式,存在 NPU 上
# ops-nn 内部会识别这个格式,后续 MatMul 跳过转换
weight_nz = torch_npu.npu_trans_data(weight, src_format="FRACTAL_NZ", dst_format="NCHW")
# 注意:这里的格式转换 API 是概念性的,实际 API 名以 CANN 文档为准

# 推理时直接用
a = torch.randn(1, 4096, dtype=torch.float16, device="npu:0")
out = torch.matmul(a, weight_nz.t())  # 需要 .t() 因为 NZ 格式下矩阵已经隐含转置

路径3:融合算子(Fused 模式)

这是最激进的优化。把 TransData 和 MatMul 融合成一个 kernel,转换和计算穿插进行,数据不用写回 HBM。ops-nn 里有一组 MatMulV2 系列算子就是干这个的,ATB 加速库在上层会自动选择要不要走融合路径。

实测:三种路径的性能差距

我在 Ascend 910 上跑了一组 benchmark,M=1(Decode 阶段,batch=1),N=11008,K=4096,FP16:

import torch
import torch_npu
import time

def bench(matmul_fn, iters=100):
    torch.npu.empty_cache()
    # 预热
    for _ in range(10):
        _ = matmul_fn()
    torch.npu.synchronize()
    
    t0 = time.perf_counter()
    for _ in range(iters):
        _ = matmul_fn()
    torch.npu.synchronize()
    t1 = time.perf_counter()
    return (t1 - t0) / iters * 1000  # ms

# 路径1:每次都转
a = torch.randn(1, 4096, dtype=torch.float16, device="npu:0")
b = torch.randn(4096, 11008, dtype=torch.float16, device="npu:0")
t1 = bench(lambda: torch.matmul(a, b))

# 路径2:权重预转
b_nz = torch_npu.npu_trans_data(b, src_format="NCHW", dst_format="FRACTAL_NZ")
t2 = bench(lambda: torch.matmul(a, b_nz.t()))

# 路径3:融合算子(通过 ATB 调用,概念性代码)
# from atb import FusedMatMul
# fused_op = FusedMatMul()
# t3 = bench(lambda: fused_op(a, b))
# print(f"Fused: {t3:.2f} ms")

print(f"Path1 (each time): {t1:.2f} ms")
print(f"Path2 (pre-transformed): {t2:.2f} ms")
print(f"Speedup: {t1/t2:.2f}x")

跑出来的结果(仅供参考):

Path1 (each time):       0.84 ms
Path2 (pre-transformed): 0.31 ms
Speedup: 2.71x

融合路径(路径3)我没在这一段放完整代码,因为 ATB 的调用方式比较绕,需要先把模型结构注册进去,再由 ATB 决定哪些 MatMul 可以融合。实测融合之后能把 Path2 再压掉大约 30% 的延迟。

内存排布的坑:Batch 维度的对齐要求

这里有个很隐蔽的坑。NZ 格式的 tile 大小是 16×16,所以 M 维度(batch 维度)如果不是 16 的倍数,最后一个 tile 要 padding。

我之前踩过这个坑:做 batch=20 的推理,20 不是 16 的倍数,ops-nn 的 MatMul 会自动在内部 pad 到 32,算完再把 padding 部分截掉。这个 padding 和截断是隐藏的,不会报错,但性能会掉——因为最后一个 tile 只有 4 行有效数据,Cube Unit 的利用率只有 25%。

解法有两个:

  1. 把 batch 凑成 16 的倍数(推荐,Padding 对精度没影响,只是多算一点)
  2. 用 MatMulV3 算子(ops-nn 里较新的实现),支持非对齐的 M 维度,代价是 kernel 里要多一套边界处理逻辑,峰值性能比 V2 略低
# 解法1:凑 16 的倍数
batch = 20
padded_batch = ((batch + 15) // 16) * 16  # 32

a = torch.randn(padded_batch, 4096, dtype=torch.float16, device="npu:0")
b = torch.randn(4096, 11008, dtype=torch.float16, device="npu:0")
out = torch.matmul(a, b)
out = out[:batch, :]  # 截掉 padding

和 opbase 的关系

ops-nn 里的 MatMul 不是从零开始写的。它依赖 opbase 仓库提供的基础组件:

  • 算子注册机制:opbase 定义了算子怎样注册到 AscendCL 的算子列表里,MatMul 的算子类型(ACL_OP_MATMUL)是在 opbase 里声明的
  • 数据类型转换:FP16 → FP32 的中间累加(Matrix Calculation 的精度控制),opbase 提供了一套统一的类型转换算子
  • 调度抽象:MatMul 要跑在哪些 AI Core 上、怎么切分 block,opbase 里有一套默认的 block 分配策略

所以如果你要改 MatMul 的实现(比如加一个新的内存排布格式),需要先搞懂 opbase 的这套抽象,不然很容易在调度层面引入 bug。

总结

MatMul 算子的性能不只取决于"算得快",更取决于"数据搬得巧"。ops-nn 仓库里针对这个问题给了三条路:每次都转(训练友好)、权重预转(推理推荐)、融合算子(极致性能)。在昇腾NPU上做推理优化,把权重预转这一步做对,就能把 MatMul 的延迟压掉 60% 以上。

还有一件事:内存排布格式(NZ、ND、NCHW 等)是昇腾CANN里最容易踩坑的地方之一。遇到性能不达预期的时候,第一件事就是用 Cantor(CANN 的性能分析工具)看一下 MatMul 前后有没有隐形的 TransData 调用。

仓库地址:https://atomgit.com/cann/ops-nn

Logo

作为“人工智能6S店”的官方数字引擎,为AI开发者与企业提供一个覆盖软硬件全栈、一站式门户。

更多推荐