昇腾CANN cann-recipes-train 实战:ZeRO 分布式训练的三阶段显存优化与 NPU 梯度累积
·
LLM 分布式训练的核心瓶颈不是算力,是显存。一个 Llama-7B 的 optimizer states 吃掉 28GB(Adam 的 momentum+variance,FP32 存两份,4×参数量),比模型权重 14GB 还大一倍。单卡跑不了 7B?不是算力不够,是显存放不下。
cann-recipes-train 提供了 ZeRO(Zero Redundancy Optimizer)的三个阶段实现,配合 hccl AllGather/ReduceScatter 通信原语,把 optimizer states、梯度、参数分片到多卡,单卡显存从 42GB 压到 6GB。
ZeRO-1: 切分 Optimizer States
# cann-recipes-train/zero/zero_stage1.py
#
# ZeRO-1: 仅切分优化器状态(Adam momentum + variance)
# 每张卡只存 1/N 的 optimizer states,其余 part 从其他卡广播
# 通信量: AllReduce gradients + Broadcast states = O(1) per parameter
import torch
import torch_npu
import torch.distributed as dist
from torch.distributed import ReduceOp
class ZeRO1Optimizer:
"""
ZeRO Stage 1: 切分 Adam optimizer states
标准 Adam: 每张卡存储完整 states
param: [D] FP16 = 2D bytes
exp_avg: [D] FP32 = 4D bytes }
exp_avg_sq: [D] FP32 = 4D bytes } optimizer states
ZeRO-1 Adam: 每张卡只存储 1/N states
param: [D/N] FP16 = 2D/N bytes
exp_avg: [D/N] FP32 = 4D/N bytes
exp_avg_sq: [D/N] FP32 = 4D/N bytes
"""
def __init__(self, params, lr=1e-4, betas=(0.9, 0.999), eps=1e-8,
world_size=None, rank=None):
self.world_size = dist.get_world_size() if world_size is None else world_size
self.rank = dist.get_rank() if rank is None else rank
self.lr = lr
self.betas = betas
self.eps = eps
# 拉平所有参数为一个 1D tensor(所有层的权重 concat)
self.param_groups = list(params)
self.total_params = sum(p.numel() for p in params)
# 每张卡分到的参数索引范围
chunk_size = (self.total_params + self.world_size - 1) // self.world_size
self.param_start = self.rank * chunk_size
self.param_end = min(self.param_start + chunk_size, self.total_params)
# 仅为本卡分配 optimizer states
local_size = self.param_end - self.param_start
self.exp_avg = torch.zeros(local_size, dtype=torch.float32, device="npu")
self.exp_avg_sq = torch.zeros(local_size, dtype=torch.float32, device="npu")
self.step_count = 0
def step(self):
"""
ZeRO-1 优化器步骤:
1. AllReduce gradients(将各卡的梯度求和,得到完整梯度)
2. 每卡只更新自己的 optimizer states 部分
3. Broadcast 各卡的参数 part → 完整参数
"""
self.step_count += 1
# Step 1: 拉平所有梯度
flat_grad = self._flatten_gradients()
# Step 2: AllReduce 梯度(各卡拿到完整梯度)
dist.all_reduce(flat_grad, op=ReduceOp.SUM, group=dist.group.WORLD)
flat_grad /= self.world_size # 平均
# Step 3: 每卡只更新自己的部分
local_grad = flat_grad[self.param_start:self.param_end]
# Adam update(只对本卡负责的参数部分)
bias_correction1 = 1 - self.betas[0] ** self.step_count
bias_correction2 = 1 - self.betas[1] ** self.step_count
self.exp_avg.mul_(self.betas[0]).add_(local_grad, alpha=1 - self.betas[0])
self.exp_avg_sq.mul_(self.betas[1]).addcmul_(
local_grad, local_grad, value=1 - self.betas[1]
)
denom = (
self.exp_avg_sq.sqrt() / (bias_correction2 ** 0.5)
).add_(self.eps)
step_size = self.lr / bias_correction1
local_update = self.exp_avg / denom * (-step_size)
# Step 4: 收集所有卡的参数更新
full_update = torch.zeros(self.total_params, dtype=torch.float32, device="npu")
full_update[self.param_start:self.param_end] = local_update
# AllGather: 每卡把自己的更新广播给所有卡
# 注意: 这里用 AllGather 收集所有 part → 每卡都有完整参数
gathered = [torch.zeros_like(local_update) for _ in range(self.world_size)]
dist.all_gather(gathered, local_update, group=dist.group.WORLD)
full_update = torch.cat(gathered)
# Step 5: 更新参数
self._apply_update(full_update)
def _flatten_gradients(self):
"""拉平所有参数的梯度为 1D"""
grads = []
for p in self.param_groups:
if p.grad is not None:
grads.append(p.grad.data.view(-1).float())
else:
grads.append(torch.zeros(p.numel(), dtype=torch.float32, device="npu"))
return torch.cat(grads)
def _apply_update(self, update):
"""将更新写回各层参数"""
offset = 0
for p in self.param_groups:
numel = p.numel()
p.data.add_(update[offset:offset+numel].view_as(p))
offset += numel
# 显存对比(Llama-7B, 8卡 Atlas 900):
# 标准 Adam: 42 GB per GPU (14B params + 28B states)
# ZeRO-1: 38.5 GB per GPU (14B params + 3.5B states) ← 省 3.5GB,远不够
# 仍然超单卡 32GB HBM → 需要 ZeRO-2 或 ZeRO-3
ZeRO-2: 附加切分梯度
# cann-recipes-train/zero/zero_stage2.py
#
# ZeRO-2: 切分 optimizer states + gradients
# 每卡只保存自己负责的梯度片段
# forward 后 AllReduce 改为 ReduceScatter → 每卡只拿自己需要的梯度
class ZeRO2Optimizer(ZeRO1Optimizer):
"""
ZeRO Stage 2: optimizer states + gradients 分片
执行流:
1. backward 计算本地梯度
2. ReduceScatter: 每卡只拿到自己负责的梯度 sum
3. 更新本地 optimizer states + 参数
4. AllGather: 恢复完整参数供下一次 forward
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.local_grad_buffer = torch.zeros(
self.param_end - self.param_start,
dtype=torch.float32, device="npu"
)
def reduce_gradients(self):
"""
ReduceScatter gradients
替代 ZeRO-1 的 AllReduce + 本地切片
直接一步完成: sum(reduce) + 分发(scatter) → 通信量减半
AllReduce: 每卡发 D, 收 D → 总通信 2D
ReduceScatter: 每卡发 D/N, 收 D/N → 总通信 2D/N (N×)
"""
flat_grad = self._flatten_gradients()
# 将梯度切分为 N 份
chunk_size = (self.total_params + self.world_size - 1) // self.world_size
local_chunk = flat_grad[self.param_start:self.param_end].clone()
# ReduceScatter: sum + scatter 合并为一步
dist.reduce_scatter_tensor(
self.local_grad_buffer,
flat_grad,
op=ReduceOp.SUM,
group=dist.group.WORLD
)
self.local_grad_buffer /= self.world_size
def step(self):
"""ZeRO-2 的完整 step"""
self.step_count += 1
# Step 1: ReduceScatter gradients
self.reduce_gradients()
# Step 2: 更新本地 optimizer states + 参数
bias_correction1 = 1 - self.betas[0] ** self.step_count
bias_correction2 = 1 - self.betas[1] ** self.step_count
self.exp_avg.mul_(self.betas[0]).add_(
self.local_grad_buffer, alpha=1 - self.betas[0]
)
self.exp_avg_sq.mul_(self.betas[1]).addcmul_(
self.local_grad_buffer, self.local_grad_buffer,
value=1 - self.betas[1]
)
denom = (
self.exp_avg_sq.sqrt() / (bias_correction2 ** 0.5)
).add_(self.eps)
step_size = self.lr / bias_correction1
local_update = self.exp_avg / denom * (-step_size)
# Step 3: AllGather 参数(恢复完整参数)
# 每卡广播自己的参数 part → 所有卡都有完整参数
self._allgather_params(local_update)
def _allgather_params(self, local_update):
"""AllGather 从各卡收集参数,恢复完整权重"""
gathered = [
torch.zeros_like(local_update) for _ in range(self.world_size)
]
dist.all_gather(gathered, local_update, group=dist.group.WORLD)
full_update = torch.cat(gathered)
self._apply_update(full_update)
# 显存对比(Llama-7B, 8卡):
# 标准 Adam: 42 GB
# ZeRO-1: 38.5 GB
# ZeRO-2: 21 GB (14B params + 1.75B states + 1.75B grads) ← 省了一半,仍超
ZeRO-3: 切分参数
# cann-recipes-train/zero/zero_stage3.py
#
# ZeRO-3: 切分 optimizer states + gradients + parameters
# 每卡在任何时刻只持有参数的 1/N
# forward/backward 时需要 AllGather 参数 → 用完立即释放
class ZeRO3Optimizer(ZeRO2Optimizer):
"""
ZeRO Stage 3: optimizer states + gradients + parameters 全分片
生命周期:
forward:
AllGather 参数 → 计算 → 释放参数(保留本地 chunk)
backward:
AllGather 参数 → 计算梯度 → ReduceScatter 梯度 → 释放参数
step:
更新本地 optimizer states + 参数 chunk → AllGather 广播
forward/backward 各一次 AllGather → 通信量 O(D)
step 一次 AllGather → 通信量 O(D)
总通信: 3D (vs ZeRO-2 的 D reduce_scatter + D allgather = 2D)
"""
def __init__(self, *args, partition_activations=True, **kwargs):
super().__init__(*args, **kwargs)
# 本卡持有的参数 chunk(初始化为零,需要时才 allgather)
self.local_param_chunk = torch.zeros(
self.param_end - self.param_start,
dtype=torch.float16, device="npu"
)
self.has_params = False
# 激活值检查点(可选,进一步省显存)
self.partition_activations = partition_activations
if partition_activations:
self.saved_activations = [] # 只存用于 backward 的最小激活值
def forward_pre_hook(self):
"""
forward 前: AllGather 参数
从各卡收集完整的参数 → 存储在 flat buffer
"""
# AllGather 参数
all_params = [
torch.zeros_like(self.local_param_chunk)
for _ in range(self.world_size)
]
dist.all_gather(all_params, self.local_param_chunk)
# 拼接为完整参数
self.full_params = torch.cat(all_params)
self.has_params = True
return self.full_params
def forward_post_hook(self):
"""
forward 后: 释放非本地参数(只保留自己的 chunk)
"""
if not self.partition_activations:
# 释放完整参数(保留本地 chunk)
self.full_params = None
self.has_params = False
def backward_pre_hook(self):
"""backward 前: 再次 AllGather 参数(因为 forward 后已释放)"""
if not self.has_params:
self.forward_pre_hook()
def backward_post_hook(self):
"""
backward 后: ReduceScatter 梯度 + 释放参数
"""
self.reduce_gradients() # 从 ZeRO-2 继承
self.full_params = None
self.has_params = False
def step(self):
"""ZeRO-3 的 step: 只需更新本地 chunk"""
self.step_count += 1
# Adam update(只更新本地 chunk,因为 params/states/grads 都本地)
bias_correction1 = 1 - self.betas[0] ** self.step_count
bias_correction2 = 1 - self.betas[1] ** self.step_count
self.exp_avg.mul_(self.betas[0]).add_(
self.local_grad_buffer, alpha=1 - self.betas[0]
)
self.exp_avg_sq.mul_(self.betas[1]).addcmul_(
self.local_grad_buffer, self.local_grad_buffer,
value=1 - self.betas[1]
)
denom = self.exp_avg_sq.sqrt() / (bias_correction2 ** 0.5) + self.eps
step_size = self.lr / bias_correction1
# 更新本地参数 chunk
self.local_param_chunk.add_(self.exp_avg / denom * (-step_size))
# 显存对比(Llama-7B, 8卡):
# 标准 Adam: 42 GB
# ZeRO-1: 38.5 GB
# ZeRO-2: 21 GB
# ZeRO-3: 6.2 GB (0GB params + 3.5GB states + 1.75GB grads + 1GB activations)
# ← 单卡 32GB 轻松装下
梯度累积——用计算换通信
# cann-recipes-train/gradient_accumulation.py
#
# 梯度累积: 多个 micro-batch 的前向/反向累积梯度,只做一次通信
# 目标: 用更少的 AllReduce 次数完成大 batch training
class GradientAccumulator:
"""
梯度累积器: N 个 micro-batch 累积后一次性通信
收益分析:
无累积: N 个 micro-batch → N 次 AllReduce → 每次通信 2D, 总 N×2D
有累积: N 个 micro-batch → 1 次 AllReduce → 总 1×2D
通信省 N× (但显存占用 N×,因为要存 N 个 batch 的激活值)
和 ZeRO 的关系:
ZeRO-2 用 ReduceScatter → 通信 = 2D/N per reduce
梯度累积 + ZeRO-2 → 通信 = 2D/N per accumulation_step
两者互补
"""
def __init__(self, model, accumulation_steps=8):
self.model = model
self.accumulation_steps = accumulation_steps
self.current_step = 0
# 累积梯度缓存(FP32 精度)
self.accumulated_grads = {}
for name, param in model.named_parameters():
if param.requires_grad:
self.accumulated_grads[name] = torch.zeros_like(
param.data, dtype=torch.float32, device="npu"
)
def backward(self, loss):
"""
反向传播 + 累积梯度
不在每次 micro-batch 后通信,只累加梯度
"""
loss.backward()
for name, param in self.model.named_parameters():
if param.requires_grad and param.grad is not None:
self.accumulated_grads[name].add_(param.grad.data.float())
param.grad = None # 释放 micro-batch 梯度
self.current_step += 1
def should_sync(self):
return self.current_step >= self.accumulation_steps
def sync_and_step(self, optimizer):
"""
梯度同步: 只在 accumulation_steps 完成后通信一次
"""
if self.current_step < self.accumulation_steps:
return False
# 将累积的梯度写回 model 参数
for name, param in self.model.named_parameters():
if param.requires_grad:
param.grad = self.accumulated_grads[name].half()
# 触发一次 AllReduce(或 ReduceScatter)
optimizer.step()
# 重置累积器
for grad in self.accumulated_grads.values():
grad.zero_()
self.current_step = 0
return True
组合策略:ZeRO-3 + 梯度累积 + 混合精度
# cann-recipes-train/zero/training_loop.py
#
# ZeRO-3 + Gradient Accumulation + Mixed Precision (AMP)
# 这是 cann-recipes-train 的默认训练 Recipe
def train_llama_7b(
model, train_loader, total_steps=100000,
micro_batch_size=1, accumulation_steps=32,
use_amp=True, use_zero3=True
):
"""
Llama-7B 训练主循环
配置: 8 张 Ascend 910, 单卡 32GB HBM
micro_batch=1, accumulation=32 → effective batch=32
显存: ~28GB per GPU(含激活值),trainable
AMP: FP16 forward/backward + FP32 master weights
"""
# ====== AMP 配置 ======
if use_amp:
scaler = torch.cuda.amp.GradScaler(
init_scale=2**16, # 初始 loss scale
growth_interval=2000, # 每 2000 步检查是否溢出
backoff_factor=0.5, # 溢出时缩小 scale
growth_factor=2.0 # 不溢出时放大 scale
)
# ====== ZeRO-3 优化器 ======
if use_zero3:
optimizer = ZeRO3Optimizer(
model.parameters(),
lr=3e-4, betas=(0.9, 0.95), eps=1e-8
)
# ====== 梯度累积 ======
accumulator = GradientAccumulator(model, accumulation_steps)
# ====== 训练循环 ======
model.train()
global_step = 0
for epoch in range(100):
for micro_step, batch in enumerate(train_loader):
# === forward (AMP + ZeRO-3 AllGather params) ===
if use_zero3:
optimizer.forward_pre_hook() # AllGather 参数
with torch.cuda.amp.autocast(dtype=torch.float16):
input_ids = batch["input_ids"].to("npu")
labels = batch["labels"].to("npu")
outputs = model(input_ids=input_ids, labels=labels)
loss = outputs.loss / accumulation_steps # 除以累积步数
if use_zero3:
optimizer.forward_post_hook() # 释放参数
# === backward (AMP unscaled + gradient accumulation) ===
if use_amp:
scaler.scale(loss).backward()
else:
loss.backward()
accumulator.backward(loss)
# === 累积完成 → 通信 + 优化器 step ===
if accumulator.should_sync():
if use_amp:
# unscale + 检查溢出 → reduce_scatter → step → allgather
scaler.unscale_(optimizer) # unscale 累积的梯度
# 检查 FP16 溢出
grad_norm = torch.nn.utils.clip_grad_norm_(
model.parameters(), max_norm=1.0
)
if not torch.isfinite(grad_norm):
print(f" ⚠️ Step {global_step}: gradient overflow, skipping")
scaler.step_skip(optimizer) # 跳过这次更新
scaler.update()
continue
scaler.step(optimizer)
scaler.update()
else:
accumulator.sync_and_step(optimizer)
global_step += 1
if global_step % 100 == 0:
print(f"Step {global_step}: loss={loss.item()*accumulation_steps:.4f}, "
f"grad_norm={grad_norm:.2f}")
if global_step >= total_steps:
break
return model
踩坑:ZeRO-3 + AMP Loss Scale 溢出——动态 scale 不够,小梯度被清为零
# ❌ 默认 loss_scale=2^16,但 Llama-7B 的某些层梯度只有 1e-8
# → 乘以 65536 后还是 < FP16 最小正数 (6.1e-5) → 下溢为零
# → 这些参数永远得不到更新 → 精度下降 0.5% per epoch
# ✅ per-tensor loss scale: 为每层设置独立的 scale
class PerTensorLossScaler:
"""
逐层 loss scale(基于各层的梯度量级)
大梯度层: scale=2^10(不溢出即可)
小梯度层: scale=2^24(兜住 1e-8 的量级)
"""
def __init__(self, model, base_scale=2**16):
self.layer_scales = {}
for name, param in model.named_parameters():
if param.requires_grad:
self.layer_scales[name] = base_scale
def update_scales(self, model):
"""根据最近 100 步的梯度统计更新 scale"""
for name, param in model.named_parameters():
if param.grad is None:
continue
grad_abs_max = param.grad.data.float().abs().max().item()
if grad_abs_max < 1e-7:
# 梯度太小 → 放大 scale
self.layer_scales[name] = min(
self.layer_scales[name] * 2, 2**24
)
elif grad_abs_max > 1e-3:
# 梯度足够大 → 正常 scale
self.layer_scales[name] = max(
self.layer_scales[name] // 2, 2**8
)
踩坑:ZeRO-3 的 AllGather 通信与 NPU 计算重叠——参数没到齐就开始 forward,导致 NaN
# ❌ 异步 AllGather 参数 → 不等通信完成就开始 forward
# forward 读到的参数一半是旧值、一半是新值 → NaN
# ✅ 使用 hccl stream 同步
class HCCLSyncAllGather:
"""
确保 AllGather 完成后再 forward
但利用全双工 HCCL:一边收集参数一边 overlap 数据预处理
"""
def __init__(self):
self.comm_stream = torch.npu.Stream() # 通信专用 stream
self.comp_stream = torch.npu.Stream() # 计算专用 stream
def allgather_params(self, local_chunk, world_size):
"""
异步 AllGather → 同步点确保参数完整
同时用 data_loader.next() 预取下一个 batch(隐藏通信延迟)
"""
with torch.npu.stream(self.comm_stream):
gathered = [torch.zeros_like(local_chunk) for _ in range(world_size)]
dist.all_gather(gathered, local_chunk)
full_params = torch.cat(gathered)
# 关键同步点: 等待 AllGather 完成
self.comm_stream.synchronize()
return full_params
def prefetch_next_batch(self, data_iter):
"""通信期间预取下一个 batch"""
with torch.npu.stream(self.comp_stream):
try:
return next(data_iter)
except StopIteration:
return None
cann-recipes-train 的 ZeRO 三阶段用通信换显存:ZeRO-1 切分 optimizer states(省 3.5GB)、ZeRO-2 附加切分梯度(省至 21GB)、ZeRO-3 全量分片参数(省至 6.2GB),AllReduce 改为 ReduceScatter+AllGather 对。配合梯度累积(8 micro-batch 合为一次通信)+ AMP 混合精度(FP16 前反向 + FP32 master weights + per-tensor loss scale),Llama-7B 在 8 卡 Atlas 900 上以 32 有效 batch size 稳定训练。踩坑:AMP Loss Scale 梯度下溢(1e-8→scale=2^24 兜底)、ZeRO-3 AllGather 异步 NaN→hccl stream 同步+data prefetch 隐藏延迟。
更多推荐


所有评论(0)