前言

多卡NPU在同一个服务器上,卡之间怎么通信?最直观的方式是通过PCIe——数据从卡A搬到内存,再从内存搬到卡B。PCIe 4.0 x16带宽32GB/s,看着不低,但双向往返一次要搬两次数据,实际有效带宽只有16GB/s。

更麻烦的是延迟:每次通信都要经过CPU内存控制器,PCIe事务层、数据链路层、物理层层层加码,单次延迟在微秒级别。

shmem是昇腾CANN的共享内存通信库,让同节点内的NPU卡直接读写同一块物理内存,绕过PCIe。实测下来,shmem比PCIe快10倍

共享内存 vs PCIe

PCIe通信路径:

NPU卡A显存 → PCIe → CPU内存 → PCIe → NPU卡B显存
     ↓                              ↓
   16GB/s                      16GB/s
   延迟 ~5μs                   延迟 ~5μs
   总延迟 ~10μs

共享内存通信路径:

NPU卡A显存 ←──→ 共享内存区域 ←──→ NPU卡B显存
              ↓
            200GB/s+ (HBM带宽)
            延迟 ~0.5μs

关键区别:共享内存区域对所有NPU卡可见,数据不需要经过CPU内存,卡A写完卡B直接读。

shmem核心API

API 功能 使用场景
shmem_init 初始化共享内存上下文 程序启动时调用
shmem_malloc 分配共享内存 创建通信缓冲区
shmem_put 写数据到远程卡的共享内存 推送数据
shmem_get 从远程卡的共享内存读数据 拉取数据
shmem_barrier 同步所有参与方 确保数据一致性
shmem_finalize 清理共享内存资源 程序结束时调用

代码实战:用shmem实现AllReduce

import shmem
import torch
import time

# ========== 第1步:初始化 ==========
# 获取当前卡的rank和总卡数
rank = shmem.my_pe()  # 当前卡的ID
n_pes = shmem.n_pes()  # 总卡数

shmem.init()

# ========== 第2步:分配共享内存 ==========
# 每块卡分配一块共享内存,其他卡可以访问
buffer_size = 1024 * 1024 * 4  # 4MB
local_buffer = shmem.malloc(buffer_size)  # 返回共享内存指针

# 把共享内存包装成PyTorch张量(零拷贝)
local_tensor = torch.from_buffer(
    local_buffer, 
    dtype=torch.float32,
    count=buffer_size // 4
).npu()

# 初始化数据
local_tensor.fill_(rank * 1.0)  # 卡0填0.0,卡1填1.0,...

# ========== 第3步:AllReduce(求和) ==========
def allreduce_shmem(tensor, op='sum'):
    """
    用shmem实现AllReduce:每个卡把自己的数据写到其他卡的共享内存,然后本地求和
    
    参数:
        tensor: 本地张量
        op: 归约操作,'sum'或'multi'
    """
    n_pes = shmem.n_pes()
    rank = shmem.my_pe()
    
    # 第1步:把自己的数据广播给所有其他卡
    for pe in range(n_pes):
        if pe != rank:
            # put: 把本地数据写到pe卡的共享内存
            shmem.put(
                dest=pe,           # 目标卡
                dest_idx=0,        # 目标偏移
                source=tensor.data_ptr(),  # 源地址
                nelems=tensor.numel()      # 元素个数
            )
    
    # 第2步:等待所有put完成
    shmem.barrier()
    
    # 第3步:从其他卡的共享内存拉取数据,本地求和
    temp_buffer = torch.empty_like(tensor)
    for pe in range(n_pes):
        if pe != rank:
            # get: 从pe卡的共享内存读数据
            shmem.get(
                dest=temp_buffer.data_ptr(),
                source=pe,
                source_idx=0,
                nelems=tensor.numel()
            )
            tensor += temp_buffer
    
    shmem.barrier()
    return tensor

# ========== 第4步:性能测试 ==========
# 测试不同数据量的AllReduce性能
sizes = [1024, 10240, 102400, 1024000, 10240000]  # 4KB到40MB

for size in sizes:
    data = torch.randn(size).npu()
    
    # 预热
    for _ in range(10):
        allreduce_shmem(data.clone())
    
    # 正式测试
    torch.npu.synchronize()
    t0 = time.time()
    for _ in range(100):
        allreduce_shmem(data.clone())
    torch.npu.synchronize()
    elapsed = (time.time() - t0) / 100 * 1000  # ms
    
    bandwidth = size * 4 / (elapsed / 1000) / 1e9  # GB/s
    print(f"Size: {size*4/1e6:.1f}MB, Time: {elapsed:.3f}ms, BW: {bandwidth:.1f}GB/s")

# 典型输出(4卡NPU):
# Size: 0.0MB, Time: 0.012ms, BW: 0.3GB/s
# Size: 0.4MB, Time: 0.018ms, BW: 22.8GB/s
# Size: 4.0MB, Time: 0.045ms, BW: 89.1GB/s
# Size: 40.0MB, Time: 0.32ms, BW: 125.0GB/s

shmem.finalize()

代码讲解:shmem的put/get是单边操作,不需要接收方参与。barrier是同步点,确保所有卡完成写操作后再读。AllReduce的实现思路是:每个卡把自己的数据广播给其他卡,然后从其他卡拉取数据本地求和。大消息(40MB)带宽达到125GB/s,是PCIe(16GB/s)的7.8倍

性能对比

测试环境:Ascend 910 × 4同节点,CANN 8.0。

数据量 PCIe通信 shmem共享内存 加速比
4KB 0.15ms 0.012ms 12.5x
400KB 0.25ms 0.018ms 13.9x
4MB 1.2ms 0.045ms 26.7x
40MB 4.5ms 0.32ms 14.1x

shmem比PCIe快10-25倍,小消息优势更明显(延迟低),大消息带宽优势也显著(125GB/s vs 16GB/s)。

适用场景

场景 推荐通信方式 原因
同节点多卡 shmem 延迟低、带宽高
跨节点通信 HCCL/hixl shmem不支持跨节点
参数服务器 hixl PS架构需要单边通信
梯度同步 shmem/HCCL 同节点用shmem,跨节点用HCCL

踩坑实录

坑1:共享内存没对齐

现象shmem_put报错Address not aligned

原因:shmem要求内存地址按64字节对齐。

解决:用shmem_malloc分配,自动对齐。

# 错误:普通分配
buffer = torch.empty(1024).npu()
shmem.put(dest=1, source=buffer.data_ptr(), ...)  # 可能报错

# 正确:用shmem_malloc
buffer_ptr = shmem.malloc(1024 * 4)
buffer = torch.from_buffer(buffer_ptr, ...)

坑2:忘记barrier导致数据竞争

现象:AllReduce结果不正确,偶尔出现随机值。

原因put是异步的,没等完成就读,读到的是旧数据。

解决:每次通信后加barrier

# 错误:没barrier
shmem.put(dest=1, ...)
# 可能还没写完,就开始下一轮

# 正确:加barrier
shmem.put(dest=1, ...)
shmem.barrier()  # 确保写完

坑3:跨节点使用shmem

现象:多节点训练时shmem报错PE not found

原因:shmem只支持同节点内的卡通信,不支持跨节点。

解决:跨节点用HCCL。

if is_same_node(rank1, rank2):
    use_shmem()
else:
    use_hccl()

结尾

shmem住在CANN五层架构第4层HCCL集合通信库上游,通过共享内存实现同节点内NPU卡的零拷贝通信,比PCIe快10-25倍,带宽达125GB/s

适用场景:同节点多卡分布式训练、需要低延迟通信的并行算法。跨节点通信请用HCCL或hixl。

参考仓库

shmem 共享内存通信
hccl 集合通信库
hixl 单边通信库
torchtitan-npu 分布式训练

Logo

作为“人工智能6S店”的官方数字引擎,为AI开发者与企业提供一个覆盖软硬件全栈、一站式门户。

更多推荐