多机多卡训练里,不是所有通信场景都能用集合通信原语解决。模型并行里的细粒度通信、自定义的AllGather变种、跨节点的部分聚合——这些场景需要更灵活的通信能力。hcomm就是昇腾CANN里提供这套能力的仓库。

hcomm在CANN里的位置

hcomm是昇腾CANN开源社区里的扩展通信库,和hccl、hixl、ascend-boost-comm这些仓库并列,属于"通信与扩展仓库"这一类。

从依赖关系来看:

hccl(集合通信)→ hcomm(扩展通信)→ 昇腾驱动层
     ↑                          ↑
  PyTorch DDP              模型并行自定义通信
  Megatron-LM               MoE Expert 并行

hccl提供的是标准的集合通信原语(AllReduce、AllGather等),接口和NCCL对齐,适合框架级别的分布式训练集成。hcomm在这些标准原语之上,提供更底层的点对点通信能力和自定义通信模式,支持在上层做更灵活的通信调度。

点对点通信的基本概念

hccl处理的是集合通信(多个节点之间的协同通信),hcomm处理的是点对点通信(两个节点之间的直接通信)。

点对点通信在昇腾NPU上有几个典型场景:

场景1:流水线并行里的参数传递

把模型的不同层切分到不同的NPU上,数据在流水线里流过每个NPU的时候,每个NPU需要把激活值或梯度发给下一个NPU。这个传递是两个相邻节点之间的点对点通信。

场景2:模型并行里的权重分片

把模型的权重切分到不同的NPU上,每个NPU计算完自己分片的权重之后,需要把结果发给其他NPU做汇聚。这个汇聚可以是AllGather(hccl),也可以是自定义的分片交换模式(hcomm)。

场景3:MoE的Expert并行

MoE(Mixture of Experts)模型里,每个Expert在不同的NPU上,某个Expert处理完一个token之后,需要把这个token的激活值发给所有其他Expert做聚合。这个发送是一对多的点对点通信,需要hcomm的灵活控制。

代码示例:用hcomm做点对点通信

hcomm的Python接口设计得很直观。下面给一个点对点通信的示例:

# hcomm 点对点通信示例:流水线并行里的数据传递
import torch
import torch.distributed as dist
import torch_npu
import hcomm  # hcomm 的 Python 绑定

# 初始化 hcomm(需要在 hccl 之后初始化)
hcomm.init()

# 获取当前进程的 rank 和 world size
rank = dist.get_rank()
world_size = dist.get_world_size()
device = torch.device("npu:0")

# 构造一个流水线通信组:相邻的 NPU 组成一对
# rank=0 发送数据给 rank=1,rank=2 发送数据给 rank=3,以此类推
send_rank = rank
recv_rank = (rank + 1) % world_size
if rank % 2 == 1:  # 奇数 rank 作为接收端
    send_rank, recv_rank = recv_rank, send_rank

# 构造发送/接收的 tensor
# 模拟流水线里传递的激活值(hidden_size=4096)
hidden_size = 4096
batch_size = 8
send_tensor = torch.randn(batch_size, hidden_size, dtype=torch.float16, device=device)
recv_tensor = torch.empty_like(send_tensor)

# hcomm 点对点发送/接收
# send/recv 是阻塞调用,等待通信完成
hcomm.send(send_tensor, dst=recv_rank)
hcomm.recv(recv_tensor, src=send_rank)

print(f"Rank {rank}: received tensor from rank {send_rank}, shape={recv_tensor.shape}")

这段代码展示了hcomm最基础的点对点通信模式:send和recv成对出现,阻塞等待。hcomm的send/recv支持两种语义:阻塞语义(等数据真正收到再返回)和非阻塞语义(发出请求后立即返回,用wait等结果)。

非阻塞通信和通信流水线

hcomm的真正价值在于支持非阻塞通信,可以和计算重叠。在流水线并行里,数据传递和计算重叠可以显著减少流水线的气泡(bubble)。

# hcomm 非阻塞通信:和计算重叠
import torch
import hcomm

def pipeline_forward(input_tensor, send_rank, recv_rank):
    """流水线并行的前向传播,数据传递和计算重叠"""
    hidden_size = input_tensor.shape[-1]
    device = input_tensor.device

    # 准备接收 buffer(非阻塞,提前分配)
    recv_tensor = torch.empty_like(input_tensor)
    recv_req = hcomm.irecv(recv_tensor, src=send_rank)

    # 本地计算(不依赖对方数据的部分)
    local_output = compute_local_layers(input_tensor)

    # 等待接收完成(非阻塞,拿到 request 后等待)
    recv_req.wait()

    # 合并本地计算结果和接收到的数据
    output = local_output + recv_tensor

    # 发送自己的输出给下一个节点(非阻塞)
    send_tensor = output.detach().clone()
    send_req = hcomm.isend(send_tensor, dst=recv_rank)

    # 等待发送完成
    send_req.wait()

    return output

这段代码的核心是通信和计算重叠:在本地计算进行的同时,接收对方的数据;本地计算完成后,合并数据,再发送自己的输出。hcomm的irecv和isend(非阻塞接口)让这个重叠成为可能。

自定义通信模式

hcomm还支持完全自定义的通信模式。下面给一个自定义AllGather的实现示例——这个AllGather不是hccl的标准版本,而是针对特定数据结构的优化版本:

# hcomm 自定义 AllGather:针对稀疏数据的优化版本
import torch
import hcomm

def sparse_allgather(tensor_list, group):
    """
    自定义的 AllGather,针对稀疏 tensor 做了优化
    场景:每个 rank 只有一小部分非零数据,全量 gather 后仍是稀疏的
    """
    world_size = group.size()
    rank = group.rank()
    device = tensor_list[0].device

    # 步骤1:先做一次 reduce-scatter,把数据分散开
    scattered = []
    num_tensors = len(tensor_list)
    chunk_size = num_tensors // world_size
    for i in range(world_size):
        start = i * chunk_size
        end = start + chunk_size if i < world_size - 1 else num_tensors
        chunk = torch.cat(tensor_list[start:end], dim=0).to(device)
        scattered.append(chunk)

    # 每个 rank 拿到自己的 chunk
    local_chunk = scattered[rank]

    # 步骤2:用 hcomm 做 ring 式的 AllGather
    result = [torch.empty_like(local_chunk) for _ in range(world_size)]
    result[rank] = local_chunk

    for step in range(1, world_size):
        src = (rank - step + world_size) % world_size
        dst = (rank + step) % world_size
        tmp = torch.empty_like(local_chunk)
        hcomm.send(local_chunk, dst=dst)
        hcomm.recv(tmp, src=src)
        result[dst] = tmp

    # 步骤3:拼接所有 chunk
    return torch.cat(result, dim=0)

这个自定义AllGather的优化点是:稀疏数据的传输量比标准AllGather小很多。标准AllGather会把所有数据广播到所有节点,数据量是O(N×size);稀疏AllGather只传输非零数据段,实际传输量可能只有标准版的10%。

和MoE集成的实际案例

MoE(Mixture of Experts)模型的大规模训练里,hcomm的使用非常密集。以Mixtral-8x7B为例,每个MoE层有8个Expert,每个Expert分布在不同的NPU上。每个token经过gate网络选出top-2的Expert之后,需要把token的激活值发给这两个Expert,计算完再发回来。

这个通信模式不是标准的集合通信(不是AllReduce、AllGather),而是一对多的点对点通信:一个token可能从NPU 0发到NPU 3和NPU 7。hccl处理不了这种灵活的一对多通信,需要hcomm。

# MoE 层里的 hcomm 通信:token 到 Expert 的映射
import torch
import hcomm
import torch.distributed as dist

def moe_dispatch tokens, gate_scores, expert_count=8):
    """
    MoE 的 dispatch 阶段:把 token 分发到对应的 Expert
    输入:tokens [batch, seq, hidden]
    gate_scores:[batch, seq, expert_count] 每个 Expert 的得分
    """
    batch_size, seq_len, hidden = tokens.shape
    device = tokens.device

    # 选 top-2 Expert
    topk_scores, topk_indices = torch.topk(gate_scores, k=2, dim=-1)
    topk_indices = topk_indices.reshape(-1)  # [batch*seq, 2]

    # 每个 token 需要发送到对应的 Expert
    # expert_rank[i] 是 Expert i 所在的 NPU rank
    # 这个映射由模型并行策略决定,这里假设已知
    dispatched = []
    for expert_id in range(expert_count):
        # 找出需要发给这个 Expert 的 token indices
        mask = (topk_indices == expert_id).any(dim=-1)
        if not mask.any():
            continue

        # 提取对应的 token
        tokens_to_send = tokens.reshape(-1, hidden)[mask]

        # 用 hcomm 发送到 Expert 所在的 NPU
        expert_rank = expert_rank_mapping[expert_id]
        hcomm.send(tokens_to_send, dst=expert_rank)

    return topk_indices  # 返回映射,供后续 combine 阶段使用

踩过的几个坑

第一个坑是通信组的生命周期管理。hcomm的通信组需要手动创建和销毁,如果通信组在用的时候被销毁了,会导致通信失败。解法是把通信组的创建和销毁放在try/finally里,确保无论计算是否成功,通信组都会被正确清理。

第二个坑是非阻塞通信的错误传播。isend和irecv是非阻塞的,发出去之后如果出错(比如对方节点挂了),错误不会被立即捕获。解法是在wait的时候检查返回值,并且定期检查通信状态。

第三个坑是通信和计算的重叠边界。没搞清楚哪些计算可以重叠、哪些必须等待通信完成就开始写代码,会导致数据依赖出错。解法是先画通信图,明确哪些计算依赖通信结果,再开始写代码。

总结

hcomm是昇腾CANN里提供扩展通信能力的仓库,支持点对点通信、非阻塞通信和自定义通信模式。它填补了hccl的标准集合通信原语覆盖不到的场景,比如流水线并行里的参数传递、MoE的Expert并行通信、自定义的稀疏数据AllGather等。

在实际的大模型训练项目里,hccl和hcomm通常是配合使用的:框架级别的梯度同步用hccl(标准AllReduce),模型并行里的细粒度通信用hcomm(点对点和自定义通信模式)。

如果你正在做昇腾上的模型并行开发,建议把hcomm的接口文档过一遍。它提供的灵活性在复杂的并行策略里非常重要。

仓库地址:https://atomgit.com/cann/hcomm

Logo

作为“人工智能6S店”的官方数字引擎,为AI开发者与企业提供一个覆盖软硬件全栈、一站式门户。

更多推荐