LLM 分布式训练的核心瓶颈不是算力,是显存。一个 Llama-7B 的 optimizer states 吃掉 28GB(Adam 的 momentum+variance,FP32 存两份,4×参数量),比模型权重 14GB 还大一倍。单卡跑不了 7B?不是算力不够,是显存放不下。

cann-recipes-train 提供了 ZeRO(Zero Redundancy Optimizer)的三个阶段实现,配合 hccl AllGather/ReduceScatter 通信原语,把 optimizer states、梯度、参数分片到多卡,单卡显存从 42GB 压到 6GB。

ZeRO-1: 切分 Optimizer States

# cann-recipes-train/zero/zero_stage1.py
#
# ZeRO-1: 仅切分优化器状态(Adam momentum + variance)
# 每张卡只存 1/N 的 optimizer states,其余 part 从其他卡广播
# 通信量: AllReduce gradients + Broadcast states = O(1) per parameter

import torch
import torch_npu
import torch.distributed as dist
from torch.distributed import ReduceOp

class ZeRO1Optimizer:
    """ZeRO stage 1: shard the Adam optimizer states across ranks.

    Plain Adam keeps the full state on every rank:
      param:        [D] FP16 = 2D bytes
      exp_avg:      [D] FP32 = 4D bytes }
      exp_avg_sq:   [D] FP32 = 4D bytes } optimizer states

    ZeRO-1 keeps only a 1/N slice of the two state tensors on each rank:
      exp_avg:      [D/N] FP32 = 4D/N bytes
      exp_avg_sq:   [D/N] FP32 = 4D/N bytes
    The parameters themselves stay fully replicated: every rank applies the
    complete update vector in ``_apply_update``.

    All parameters are treated as one flat 1-D vector (concatenated in
    iteration order). Rank ``r`` owns the half-open index range
    ``[param_start, param_end)`` of that vector.
    """

    def __init__(self, params, lr=1e-4, betas=(0.9, 0.999), eps=1e-8,
                 world_size=None, rank=None):
        """Build the shard layout and allocate this rank's Adam state.

        Args:
            params: iterable of parameter tensors (a generator such as
                ``model.parameters()`` is fine).
            lr: learning rate.
            betas: Adam ``(beta1, beta2)`` coefficients.
            eps: term added to the denominator for numerical stability.
            world_size: number of ranks; queried from ``torch.distributed``
                when omitted.
            rank: this process's rank; queried from ``torch.distributed``
                when omitted.
        """
        self.world_size = dist.get_world_size() if world_size is None else world_size
        self.rank = dist.get_rank() if rank is None else rank
        self.lr = lr
        self.betas = betas
        self.eps = eps

        # Materialise first and count from the list: `params` is usually a
        # one-shot generator, so iterating it a second time would yield
        # nothing and make total_params 0.
        self.param_groups = list(params)
        self.total_params = sum(p.numel() for p in self.param_groups)

        # Keep optimizer state next to the parameters; fall back to the NPU
        # only when there is no parameter to ask.
        self._device = self.param_groups[0].device if self.param_groups else "npu"

        # Ceil-divide so every element is owned by exactly one rank. Both
        # bounds are clamped so ranks past the end get an empty shard rather
        # than a negative size.
        self.chunk_size = (self.total_params + self.world_size - 1) // self.world_size
        self.param_start = min(self.rank * self.chunk_size, self.total_params)
        self.param_end = min(self.param_start + self.chunk_size, self.total_params)

        # Allocate Adam state for the local shard only.
        local_size = self.param_end - self.param_start
        self.exp_avg = torch.zeros(local_size, dtype=torch.float32, device=self._device)
        self.exp_avg_sq = torch.zeros(local_size, dtype=torch.float32, device=self._device)
        self.step_count = 0

    def step(self):
        """Run one ZeRO-1 optimizer step.

        1. All-reduce the flattened gradients and average them.
        2. Apply the Adam update to this rank's state shard only.
        3. All-gather every rank's update slice and apply the full update to
           the (replicated) parameters.
        """
        self.step_count += 1

        # Step 1: flatten and average gradients across ranks.
        flat_grad = self._flatten_gradients()
        dist.all_reduce(flat_grad, op=ReduceOp.SUM, group=dist.group.WORLD)
        flat_grad /= self.world_size

        # Step 2: Adam on the locally owned slice.
        local_grad = flat_grad[self.param_start:self.param_end]

        bias_correction1 = 1 - self.betas[0] ** self.step_count
        bias_correction2 = 1 - self.betas[1] ** self.step_count

        self.exp_avg.mul_(self.betas[0]).add_(local_grad, alpha=1 - self.betas[0])
        self.exp_avg_sq.mul_(self.betas[1]).addcmul_(
            local_grad, local_grad, value=1 - self.betas[1]
        )

        denom = (
            self.exp_avg_sq.sqrt() / (bias_correction2 ** 0.5)
        ).add_(self.eps)

        step_size = self.lr / bias_correction1
        local_update = self.exp_avg / denom * (-step_size)

        # Step 3: all_gather needs equally sized tensors on every rank, but
        # trailing shards are shorter when total_params is not a multiple of
        # world_size. Pad to chunk_size, gather, then drop the padding.
        padded = local_update.new_zeros(self.chunk_size)
        padded[:local_update.numel()] = local_update
        gathered = [torch.empty_like(padded) for _ in range(self.world_size)]
        dist.all_gather(gathered, padded, group=dist.group.WORLD)
        full_update = torch.cat(gathered)[:self.total_params]

        # Step 4: apply the complete update to every parameter.
        self._apply_update(full_update)

    def _flatten_gradients(self):
        """Return all gradients concatenated into one FP32 1-D tensor.

        Parameters without a gradient contribute zeros so that offsets in
        the flat vector always line up with the shard layout.
        """
        if not self.param_groups:
            return torch.zeros(0, dtype=torch.float32, device=self._device)
        grads = []
        for p in self.param_groups:
            if p.grad is not None:
                # reshape (not view) so non-contiguous gradients also work
                grads.append(p.grad.detach().reshape(-1).float())
            else:
                grads.append(torch.zeros(p.numel(), dtype=torch.float32, device=p.device))
        return torch.cat(grads)

    def _apply_update(self, update):
        """Add the flat ``update`` vector back onto each parameter in place."""
        offset = 0
        for p in self.param_groups:
            numel = p.numel()
            p.data.add_(update[offset:offset + numel].view_as(p))
            offset += numel

# 显存对比(Llama-7B, 8卡 Atlas 900):
# 标准 Adam:  42 GB per GPU (14B params + 28B states)
# ZeRO-1:     38.5 GB per GPU (14B params + 3.5B states)  ← 省 3.5GB,远不够
# 仍然超单卡 32GB HBM → 需要 ZeRO-2 或 ZeRO-3

ZeRO-2: 附加切分梯度

# cann-recipes-train/zero/zero_stage2.py
#
# ZeRO-2: 切分 optimizer states + gradients
# 每卡只保存自己负责的梯度片段
# backward 后的 AllReduce 改为 ReduceScatter → 每卡只拿自己需要的梯度

class ZeRO2Optimizer(ZeRO1Optimizer):
    """ZeRO stage 2: shard optimizer states and gradients.

    Flow of one step:
    1. backward produces local gradients;
    2. reduce-scatter leaves each rank with the summed gradient of its own
       shard only;
    3. the rank updates its Adam state and computes its update slice;
    4. all-gather of the update slices restores identical full parameters
       on every rank.
    """

    def __init__(self, *args, **kwargs):
        """Accept the same arguments as ``ZeRO1Optimizer``."""
        super().__init__(*args, **kwargs)
        # Same length, dtype and device as the local Adam state, so the
        # buffer always lives where the state it feeds lives.
        self.local_grad_buffer = torch.zeros_like(self.exp_avg)

    def _shard_len(self):
        """Length of one full (unpadded-rank) shard: ceil(total / world)."""
        return (self.total_params + self.world_size - 1) // self.world_size

    def reduce_gradients(self):
        """Reduce-scatter the flat gradient into ``local_grad_buffer``.

        Replaces stage 1's "all-reduce then slice locally" with a single
        collective that sums and distributes in one go, then averages.

        NOTE(review): the original comment claimed an N-fold traffic saving
        over all-reduce; for ring-style collectives reduce-scatter is
        usually about half of an all-reduce. Confirm against HCCL.
        """
        flat_grad = self._flatten_gradients()
        shard_len = self._shard_len()
        padded_len = shard_len * self.world_size

        # reduce_scatter_tensor needs input length == world_size * output
        # length. Zero-pad the tail when total_params is not a multiple of
        # world_size; zeros do not affect the sum.
        if flat_grad.numel() != padded_len:
            padded = flat_grad.new_zeros(padded_len)
            padded[:flat_grad.numel()] = flat_grad
            flat_grad = padded

        local_len = self.local_grad_buffer.numel()
        if local_len == shard_len:
            dist.reduce_scatter_tensor(
                self.local_grad_buffer,
                flat_grad,
                op=ReduceOp.SUM,
                group=dist.group.WORLD
            )
        else:
            # A short trailing shard: receive a full-length piece, keep the
            # part that maps to real parameters.
            shard = flat_grad.new_empty(shard_len)
            dist.reduce_scatter_tensor(
                shard,
                flat_grad,
                op=ReduceOp.SUM,
                group=dist.group.WORLD
            )
            self.local_grad_buffer.copy_(shard[:local_len])
        self.local_grad_buffer /= self.world_size

    def _local_adam_update(self):
        """Advance the local Adam state and return this rank's update slice."""
        bias_correction1 = 1 - self.betas[0] ** self.step_count
        bias_correction2 = 1 - self.betas[1] ** self.step_count

        self.exp_avg.mul_(self.betas[0]).add_(
            self.local_grad_buffer, alpha=1 - self.betas[0]
        )
        self.exp_avg_sq.mul_(self.betas[1]).addcmul_(
            self.local_grad_buffer, self.local_grad_buffer,
            value=1 - self.betas[1]
        )

        denom = (
            self.exp_avg_sq.sqrt() / (bias_correction2 ** 0.5)
        ).add_(self.eps)

        step_size = self.lr / bias_correction1
        return self.exp_avg / denom * (-step_size)

    def step(self):
        """Run one full ZeRO-2 step: reduce-scatter, Adam, all-gather."""
        self.step_count += 1

        # Step 1: each rank receives only its gradient shard.
        self.reduce_gradients()

        # Step 2: update local optimizer state.
        local_update = self._local_adam_update()

        # Step 3: share update slices so every rank ends with the same
        # full parameters.
        self._allgather_params(local_update)

    def _allgather_params(self, local_update):
        """All-gather the per-rank update slices and apply the full update."""
        shard_len = self._shard_len()
        send = local_update
        if send.numel() != shard_len:
            # all_gather needs equal sizes on every rank; pad short shards.
            send = local_update.new_zeros(shard_len)
            send[:local_update.numel()] = local_update
        gathered = [torch.empty_like(send) for _ in range(self.world_size)]
        dist.all_gather(gathered, send, group=dist.group.WORLD)
        # Drop the padding that sits past the last real parameter.
        full_update = torch.cat(gathered)[:self.total_params]
        self._apply_update(full_update)

# 显存对比(Llama-7B, 8卡):
# 标准 Adam: 42 GB
# ZeRO-1:    38.5 GB
# ZeRO-2:    21 GB (14B params + 1.75B states + 1.75B grads)  ← 省了一半,仍超

ZeRO-3: 切分参数

# cann-recipes-train/zero/zero_stage3.py
#
# ZeRO-3: 切分 optimizer states + gradients + parameters
# 每卡在任何时刻只持有参数的 1/N
# forward/backward 时需要 AllGather 参数 → 用完立即释放

class ZeRO3Optimizer(ZeRO2Optimizer):
    """ZeRO stage 3: shard optimizer states, gradients and parameters.

    Intended lifecycle (driven by the caller through the hook methods):
      forward:   forward_pre_hook (all-gather params) -> compute ->
                 forward_post_hook (drop the gathered copy)
      backward:  backward_pre_hook (re-gather if dropped) -> compute grads ->
                 backward_post_hook (reduce-scatter grads, drop params)
      step:      Adam on the local parameter chunk only.

    ``step`` does not communicate; it consumes ``local_grad_buffer`` as
    filled by ``backward_post_hook``.
    """

    def __init__(self, *args, partition_activations=True, **kwargs):
        """Accept ``ZeRO2Optimizer`` arguments plus ``partition_activations``.

        Args:
            partition_activations: when true, an empty ``saved_activations``
                list is created for the caller to use; this class itself
                never reads it.
        """
        super().__init__(*args, **kwargs)

        # FP16 copy of the slice of the flat parameter vector owned by this
        # rank. It must start from the real parameter values: a zero-filled
        # chunk would make the first all-gather return all-zero weights.
        local_size = self.param_end - self.param_start
        self.local_param_chunk = torch.zeros(
            local_size, dtype=torch.float16, device=self.exp_avg.device
        )
        self._load_local_chunk()

        # Gathered full parameter vector; only alive between a pre hook and
        # the matching post hook.
        self.full_params = None
        self.has_params = False

        self.partition_activations = partition_activations
        if partition_activations:
            self.saved_activations = []

    def _load_local_chunk(self):
        """Copy the owned range of every parameter into ``local_param_chunk``.

        Walks parameters one by one instead of concatenating them all, so no
        temporary full-size flat copy is created.
        """
        offset = 0
        for p in self.param_groups:
            numel = p.numel()
            lo = max(offset, self.param_start)
            hi = min(offset + numel, self.param_end)
            if lo < hi:
                src = p.data.reshape(-1)[lo - offset:hi - offset]
                # copy_ converts dtype/device as needed
                self.local_param_chunk[
                    lo - self.param_start:hi - self.param_start
                ].copy_(src)
            offset += numel

    def forward_pre_hook(self):
        """All-gather every rank's chunk into one flat parameter tensor.

        Returns:
            The gathered FP16 tensor of length ``total_params`` (also stored
            as ``self.full_params``).
        """
        shard_len = (self.total_params + self.world_size - 1) // self.world_size
        send = self.local_param_chunk
        if send.numel() != shard_len:
            # all_gather needs equal sizes on every rank; pad short chunks.
            send = self.local_param_chunk.new_zeros(shard_len)
            send[:self.local_param_chunk.numel()] = self.local_param_chunk

        all_params = [torch.empty_like(send) for _ in range(self.world_size)]
        dist.all_gather(all_params, send)

        # Concatenate and drop any padding past the last real parameter.
        self.full_params = torch.cat(all_params)[:self.total_params]
        self.has_params = True

        return self.full_params

    def forward_post_hook(self):
        """Drop the gathered parameters after forward; the local chunk stays.

        The release is unconditional: it previously depended on
        ``partition_activations`` being false, which with the default
        settings meant the full copy was never freed.
        """
        self.full_params = None
        self.has_params = False

    def backward_pre_hook(self):
        """Re-gather parameters before backward if they were released."""
        if not self.has_params:
            self.forward_pre_hook()

    def backward_post_hook(self):
        """Reduce-scatter gradients (inherited) and drop gathered params."""
        self.reduce_gradients()
        self.full_params = None
        self.has_params = False

    def step(self):
        """Apply Adam to the local chunk using ``local_grad_buffer``.

        No collective runs here: parameters, states and gradients for this
        shard are all local. Call ``backward_post_hook`` first so the
        gradient buffer is up to date.
        """
        self.step_count += 1

        bias_correction1 = 1 - self.betas[0] ** self.step_count
        bias_correction2 = 1 - self.betas[1] ** self.step_count

        self.exp_avg.mul_(self.betas[0]).add_(
            self.local_grad_buffer, alpha=1 - self.betas[0]
        )
        self.exp_avg_sq.mul_(self.betas[1]).addcmul_(
            self.local_grad_buffer, self.local_grad_buffer,
            value=1 - self.betas[1]
        )

        denom = self.exp_avg_sq.sqrt() / (bias_correction2 ** 0.5) + self.eps
        step_size = self.lr / bias_correction1

        # In-place update of the FP16 chunk from the FP32 Adam result.
        self.local_param_chunk.add_(self.exp_avg / denom * (-step_size))

# 显存对比(Llama-7B, 8卡):
# 标准 Adam:   42 GB
# ZeRO-1:      38.5 GB
# ZeRO-2:      21 GB
# ZeRO-3:      6.2 GB (0GB params + 3.5GB states + 1.75GB grads + 1GB activations)
#                 ← 单卡 32GB 轻松装下

梯度累积——用计算换通信

# cann-recipes-train/gradient_accumulation.py
#
# 梯度累积: 多个 micro-batch 的前向/反向累积梯度,只做一次通信
# 目标: 用更少的 AllReduce 次数完成大 batch training

class GradientAccumulator:
    """Accumulate gradients over several micro-batches, then step once.

    With ``accumulation_steps = N`` the optimizer (and whatever collective
    it triggers) runs once per N micro-batches instead of once per
    micro-batch.

    Memory cost is one extra FP32 buffer per trainable parameter. Each
    micro-batch's gradient is released right after it is added to the
    buffer, so the per-micro-batch gradient does not pile up.
    """

    def __init__(self, model, accumulation_steps=8):
        """
        Args:
            model: module whose ``named_parameters()`` are tracked.
            accumulation_steps: micro-batches per optimizer step.
        """
        self.model = model
        self.accumulation_steps = accumulation_steps
        self.current_step = 0

        # FP32 accumulation buffers, created on each parameter's own device
        # so the in-place add in backward() never crosses devices.
        self.accumulated_grads = {
            name: torch.zeros_like(param.data, dtype=torch.float32)
            for name, param in model.named_parameters()
            if param.requires_grad
        }

    def backward(self, loss):
        """Back-propagate ``loss`` and fold the gradients into the buffers.

        No communication happens here.
        """
        loss.backward()

        for name, param in self.model.named_parameters():
            if param.requires_grad and param.grad is not None:
                self.accumulated_grads[name].add_(param.grad.data.float())
                param.grad = None  # free the micro-batch gradient

        self.current_step += 1

    def should_sync(self):
        """Return True once ``accumulation_steps`` micro-batches were seen."""
        return self.current_step >= self.accumulation_steps

    def sync_and_step(self, optimizer):
        """Hand the accumulated gradients to ``optimizer`` and reset.

        Args:
            optimizer: any object with a ``step()`` method.

        Returns:
            False (and does nothing) while accumulation is still in
            progress; True after the optimizer step has run.
        """
        if self.current_step < self.accumulation_steps:
            return False

        # Write the totals back as .grad. The gradient must have the
        # parameter's dtype (a forced .half() breaks FP32 parameters), and it
        # must be a copy so the reset below does not zero it as well.
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                param.grad = self.accumulated_grads[name].to(
                    dtype=param.dtype, copy=True
                )

        # The optimizer performs its own gradient communication, if any.
        optimizer.step()

        # Start the next accumulation window from zero.
        for grad in self.accumulated_grads.values():
            grad.zero_()
        self.current_step = 0

        return True

组合策略:ZeRO-3 + 梯度累积 + 混合精度

# cann-recipes-train/zero/training_loop.py
#
# ZeRO-3 + Gradient Accumulation + Mixed Precision (AMP)
# 这是 cann-recipes-train 的默认训练 Recipe

def train_llama_7b(
    model, train_loader, total_steps=100000,
    micro_batch_size=1, accumulation_steps=32,
    use_amp=True, use_zero3=True
):
    """Training loop combining ZeRO-3, gradient accumulation and FP16 AMP.

    Args:
        model: module called as ``model(input_ids=..., labels=...)`` whose
            output exposes ``.loss``.
        train_loader: iterable of dicts with ``"input_ids"`` and ``"labels"``.
        total_steps: number of optimizer steps after which training stops.
        micro_batch_size: not used by this function; kept so existing
            callers keep working.
        accumulation_steps: micro-batches accumulated per optimizer step.
        use_amp: run forward under FP16 autocast with dynamic loss scaling.
        use_zero3: use ``ZeRO3Optimizer``; otherwise plain ``AdamW``.

    Returns:
        The trained ``model``.

    Loss scaling is done by hand rather than with ``GradScaler``: the ZeRO
    optimizers in this file are not ``torch.optim.Optimizer`` instances, so
    ``GradScaler.unscale_``/``step`` cannot drive them.

    NOTE(review): ``ZeRO3Optimizer.step`` updates its own
    ``local_param_chunk``; nothing visible here copies that back into
    ``model``. Confirm how the gathered weights are meant to reach the model.
    """
    # ====== optimizer ======
    if use_zero3:
        optimizer = ZeRO3Optimizer(
            model.parameters(),
            lr=3e-4, betas=(0.9, 0.95), eps=1e-8
        )
    else:
        # Without this branch `optimizer` was simply undefined.
        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=3e-4, betas=(0.9, 0.95), eps=1e-8
        )

    # ====== gradient accumulation ======
    accumulator = GradientAccumulator(model, accumulation_steps)

    # ====== dynamic loss scaling state (AMP only) ======
    loss_scale = 2.0 ** 16 if use_amp else 1.0
    growth_interval = 2000   # clean steps before the scale is raised
    backoff_factor = 0.5     # applied on overflow
    growth_factor = 2.0      # applied after growth_interval clean steps
    clean_steps = 0
    # Clipping stays tied to the AMP path; an infinite max_norm makes
    # clip_grad_norm_ only measure the norm.
    max_norm = 1.0 if use_amp else float("inf")

    model.train()
    global_step = 0

    for epoch in range(100):
        # A bare `break` in the inner loop would only end the current epoch.
        if global_step >= total_steps:
            break

        for batch in train_loader:
            # === forward ===
            if use_zero3:
                optimizer.forward_pre_hook()  # all-gather parameters

            with torch.npu.amp.autocast(enabled=use_amp, dtype=torch.float16):
                input_ids = batch["input_ids"].to("npu")
                labels = batch["labels"].to("npu")

                outputs = model(input_ids=input_ids, labels=labels)
                # Divide so the accumulated gradient is a mean over the window.
                loss = outputs.loss / accumulation_steps

            if use_zero3:
                optimizer.forward_post_hook()
                optimizer.backward_pre_hook()

            # === backward ===
            # Exactly one backward per micro-batch: the accumulator calls
            # loss.backward() itself. The loss is scaled here and the
            # gradients are unscaled once, at sync time.
            accumulator.backward(loss * loss_scale)

            if not accumulator.should_sync():
                continue

            # === accumulation window complete ===
            # Write unscaled gradients back onto the parameters.
            inv_scale = 1.0 / loss_scale
            for name, param in model.named_parameters():
                if param.requires_grad:
                    param.grad = (
                        accumulator.accumulated_grads[name] * inv_scale
                    ).to(param.dtype)

            grad_norm = torch.nn.utils.clip_grad_norm_(
                model.parameters(), max_norm=max_norm
            )

            overflow = False
            if use_amp:
                found_inf = (~torch.isfinite(grad_norm)).to(torch.float32).reshape(1)
                # Every rank must take the same branch, otherwise some ranks
                # would enter the collectives below and others would not.
                if dist.is_available() and dist.is_initialized():
                    dist.all_reduce(found_inf, op=dist.ReduceOp.MAX)
                overflow = found_inf.item() > 0

            if overflow:
                print(f"  ⚠️ Step {global_step}: gradient overflow, skipping")
                loss_scale *= backoff_factor
                clean_steps = 0
            else:
                if use_zero3:
                    optimizer.backward_post_hook()  # reduce-scatter gradients
                optimizer.step()
                if use_amp:
                    clean_steps += 1
                    if clean_steps % growth_interval == 0:
                        loss_scale *= growth_factor

            # Reset the window in both cases; otherwise should_sync() would
            # stay true forever after a skipped step.
            for buf in accumulator.accumulated_grads.values():
                buf.zero_()
            accumulator.current_step = 0
            for param in model.parameters():
                param.grad = None

            if overflow:
                continue

            global_step += 1

            if global_step % 100 == 0:
                print(f"Step {global_step}: loss={loss.item()*accumulation_steps:.4f}, "
                      f"grad_norm={grad_norm:.2f}")

            if global_step >= total_steps:
                break

    return model

踩坑:ZeRO-3 + AMP Loss Scale 溢出——动态 scale 不够,小梯度被清为零

# ❌ 默认 loss_scale=2^16,但 Llama-7B 的某些层梯度可低至 1e-10 量级
# → 乘以 65536 后约 6.6e-6,仍 < FP16 最小正规数 (6.1e-5) → 落入次正规数区间,精度大幅损失甚至下溢为零
# → 这些参数几乎得不到有效更新 → 精度下降 0.5% per epoch

# ✅ per-tensor loss scale: 为每层设置独立的 scale
class PerTensorLossScaler:
    """Keep an independent loss scale for every trainable parameter.

    Each scale starts at ``base_scale`` and is doubled or halved by
    ``update_scales`` depending on the largest absolute value in the
    parameter's current gradient, staying within [2**8, 2**24].
    """

    def __init__(self, model, base_scale=2**16):
        """Create one scale entry per parameter that requires grad."""
        self.layer_scales = {
            name: base_scale
            for name, param in model.named_parameters()
            if param.requires_grad
        }

    def update_scales(self, model):
        """Adjust each scale from the gradients currently on ``model``.

        Only the present ``.grad`` tensors are inspected; no history is
        kept. Parameters without a gradient are left untouched.
        """
        floor, ceiling = 2**8, 2**24
        for name, param in model.named_parameters():
            grad = param.grad
            if grad is None:
                continue

            peak = grad.data.float().abs().max().item()

            if peak < 1e-7:
                # Tiny gradients: double the scale, capped at the ceiling.
                doubled = self.layer_scales[name] * 2
                self.layer_scales[name] = min(doubled, ceiling)
            elif peak > 1e-3:
                # Large gradients: halve the scale, but not below the floor.
                halved = self.layer_scales[name] // 2
                self.layer_scales[name] = max(halved, floor)

踩坑:ZeRO-3 的 AllGather 通信与 NPU 计算重叠——参数没到齐就开始 forward,导致 NaN

# ❌ 异步 AllGather 参数 → 不等通信完成就开始 forward
# forward 读到的参数一半是旧值、一半是新值 → NaN

# ✅ 使用 hccl stream 同步
class HCCLSyncAllGather:
    """Run the parameter all-gather on a dedicated stream and wait for it.

    ``allgather_params`` only returns after the collective has finished, so
    a following forward pass never sees partially gathered weights.
    ``prefetch_next_batch`` is a separate helper the caller can use to fetch
    the next batch.
    """

    def __init__(self):
        self.comm_stream = torch.npu.Stream()  # stream used for the collective
        self.comp_stream = torch.npu.Stream()  # stream used around prefetching

    def allgather_params(self, local_chunk, world_size):
        """All-gather ``local_chunk`` from every rank and concatenate.

        Args:
            local_chunk: this rank's tensor; every rank must pass the same
                shape and dtype.
            world_size: number of ranks taking part.

        Returns:
            The concatenation of all ranks' chunks, complete on return.
        """
        # Work queued on the caller's stream (for example the optimizer
        # update that just wrote local_chunk) must finish before the side
        # stream reads it; without this the collective can send stale data.
        self.comm_stream.wait_stream(torch.npu.current_stream())

        with torch.npu.stream(self.comm_stream):
            gathered = [torch.empty_like(local_chunk) for _ in range(world_size)]
            dist.all_gather(gathered, local_chunk)
            full_params = torch.cat(gathered)

        # Block the host until the gather has completed.
        self.comm_stream.synchronize()

        return full_params

    def prefetch_next_batch(self, data_iter):
        """Return the next item of ``data_iter``, or None when exhausted."""
        with torch.npu.stream(self.comp_stream):
            return next(data_iter, None)

cann-recipes-train 的 ZeRO 三阶段用通信换显存:ZeRO-1 切分 optimizer states(省 3.5GB)、ZeRO-2 附加切分梯度(省至 21GB)、ZeRO-3 全量分片参数(省至 6.2GB),AllReduce 改为 ReduceScatter+AllGather 对。配合梯度累积(8 micro-batch 合为一次通信)+ AMP 混合精度(FP16 前反向 + FP32 master weights + per-tensor loss scale),Llama-7B 在 8 卡 Atlas 900 上以 32 有效 batch size 稳定训练。踩坑:AMP Loss Scale 梯度下溢(1e-8→scale=2^24 兜底)、ZeRO-3 AllGather 异步 NaN→hccl stream 同步+data prefetch 隐藏延迟。

Logo

作为“人工智能6S店”的官方数字引擎,为AI开发者与企业提供一个覆盖软硬件全栈、一站式门户。

更多推荐