LLaMA-13B 在 8 张 NPU 上训练,batch_size=128、seq_len=4096 → 单卡需要 28GB 显存放激活(activation),但 Ascend 910 只有 56GB HBM——28GB 激活 + 13GB 参数 + 8GB 优化器状态 = 49GB,只剩 7GB 余量。如果 seq_len 涨到 8192,激活翻倍到 56GB → OOM。

两条路径:梯度累积(减少每步的激活量,多步后才同步)和 Gradient Checkpoint(不存所有激活,反向时重算)。单独用任一都不够——梯度累积解决不了单层激活 > 剩余 HBM 的问题,Checkpoint 的重算开销降低吞吐 30%。联合用才最优:梯度累积把 batch 切成 micro-batch,每个 micro-batch 只存 checkpoint 标记层的激活,中间层反向时重算。

梯度累积:用时间换显存

标准训练:一个 step = forward(128 samples) → backward → optimizer.step()。128 samples 的激活全存 HBM。

梯度累积:128 = 8 × 16 → forward(16 samples) → backward(16 samples) [不 step] → forward(16) → backward(16) [不 step] → …(重复 6 次)… → optimizer.step() [等效 batch=128]。

# cann-recipes-train/examples/gradient_accumulation.py

import torch
import torch_npu
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def train_with_grad_accum(model, dataloader, optimizer, config):
    """
    config.accum_steps = 8       # 累积 8 步
    config.micro_batch_size = 16 # 每步 16 samples
    config.effective_batch = 128 # 等效 batch
    """
    model.train()
    optimizer.zero_grad()

    for step, batch in enumerate(dataloader):
        # 把 batch 切成 micro-batches
        micro_batches = batch.chunk(config.accum_steps)

        for i, micro_batch in enumerate(micro_batches):
            # 前向(只存当前 micro-batch 的激活)
            with torch.autocast("npu", dtype=torch.float16):
                loss = model(micro_batch) / config.accum_steps
                # ↑ 除以 accum_steps,让 loss 反向时的梯度是平均梯度
                #   等价于所有 micro-batch 的 loss 求和后反向
                #   ∇(loss/8) = ∇loss/8 → 8 步累积后 = ∇loss

            # 反向(梯度留在 param.grad 里,不清零)
            loss.backward()

            # 释放 micro-batch 的激活(让 HBM 可以给下一步用)
            del micro_batch, loss
            torch.npu.empty_cache()

        # 8 个 micro-batch 的梯度都累积在 param.grad → 一次 step
        # 梯度同步(FSDP AllReduce)也只做一次,省 7 次通信
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        optimizer.zero_grad()

梯度累积的通信收益:8 个 micro-batch 只做 1 次梯度 AllReduce(vs 8 次),省了 87.5% 的通信量。但显存收益有限——micro-batch 的激活在内存中,不在 HBM 中,只有当前 forward 的激活在 HBM。

Gradient Checkpoint:用计算换显存

标准训练:forward 把每层的激活存下来,backward 直接读。LLaMA-13B 的 40 个 Transformer 层,每层存 0.7GB 激活(batch=16, seq=4096)→ 28GB。

Checkpoint:只存 N 个「检查点」层的激活,其他层在 backward 时从检查点重算 forward。比如每 2 层存一个检查点 → 存 20 层(14GB),省 14GB。代价:每层额外一次 forward → 额外 50% 计算量(40 层 forward + backward = 80 层的工作量,with checkpoint 变成 40 + 40 + 20 = 100 层 = 25% 额外计算)。

# cann-recipes-train/examples/gradient_checkpointing.py

import torch.utils.checkpoint as checkpoint

class TransformerWithCheckpoint(nn.Module):
    def __init__(self, num_layers, hidden_dim, checkpoint_every=2):
        super().__init__()
        self.num_layers = num_layers
        self.checkpoint_every = checkpoint_every
        self.layers = nn.ModuleList([
            TransformerLayer(hidden_dim) for _ in range(num_layers)
        ])

    def forward(self, x):
        for layer_idx, layer in enumerate(self.layers):
            if layer_idx % self.checkpoint_every == 0:
                # 检查点层:forward 正常执行,激活存 HBM
                x = layer(x)
            else:
                # 非检查点层:用 checkpoint(不存激活,重算)
                x = checkpoint.checkpoint(
                    layer, x,
                    use_reentrant=False  # PyTorch 2.0+ 推荐
                )
        return x

# 配置
model = TransformerWithCheckpoint(
    num_layers=40,
    hidden_dim=5120,
    checkpoint_every=2  # 每 2 层一个检查点 → 20 个检查点
)

# 显存占用
# 无 checkpoint: 40 层 × 0.7GB = 28GB
# checkpoint_every=2: 20 层 × 0.7GB = 14GB(省 14GB)
# checkpoint_every=4: 10 层 × 0.7GB = 7GB(省 21GB,但重算 75%)

联合策略:梯度累积 + Checkpoint 的叠加

# cann-recipes-train/examples/combined_strategy.py

def train_combined(model, dataloader, optimizer, config):
    """联合策略:梯度累积 8 step + checkpoint_every=2"""

    accum_steps = 8
    micro_batch_size = 16

    # 每个 micro-batch 的激活:20 层 checkpoint × 0.7GB = 14GB
    # 加上参数 13GB + 优化器 8GB = 35GB → 安全(56GB HBM)
    optimizer.zero_grad()

    for step, batch in enumerate(dataloader):
        micro_batches = batch.chunk(accum_steps)

        for i, mb in enumerate(micro_batches):
            # 最后一个 micro-batch 同步梯度(AllReduce)
            # 其他 micro-batch 只在本地累积
            with model.no_sync() if i < accum_steps - 1 else contextlib.nullcontext():
                with torch.autocast("npu", dtype=torch.float16):
                    loss = model(mb) / accum_steps

                loss.backward()
                del mb, loss

            torch.npu.empty_cache()

        # 8 步累积后 → 1 次 optimizer.step()
        clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        optimizer.zero_grad()

显存分析

LLaMA-13B (40 layers, seq=4096, hidden=5120) on 8× Ascend 910

| 策略 | 激活显存 | 参数+优化器 | HBM 总计 | 通信次数/step | 吞吐 |
|------|---------|-----------|---------|------------|------|
| 无优化 (bs=128) | 28GB | 21GB | 49GB (OK) | 1 | 8,200 t/s |
| 无优化 (bs=256) | 56GB | 21GB | 77GB OOM!| — | — |
| GradAcc=8 (bs=16 ×8) | 7GB | 21GB | 28GB (OK) | 1 | 7,800 t/s |
| Checkpoint (every=2) | 14GB | 21GB | 35GB (OK) | 1 | 6,400 t/s |
| **Checkpoint + GradAcc=8** | 3.5GB | 21GB | 24.5GB (OK) | 1 | 6,100 t/s |
| **Checkpoint + GradAcc=4** | 7GB | 21GB | 28GB (OK) | 1 | 6,800 t/s |

最优:Checkpoint (every=2) + GradAcc=4。吞吐 6,800 tokens/s(比无优化的 8,200 低 17%,但 batch 从 128 扩到 512 的等效 batch,收敛快 4×)。

踩坑一:Checkpoint 放错位置——嵌入层和输出层

embedding 层和 lm_head(logits 层)不能放 checkpoint——这两层是参数的入口和出口,forward 不产生大显存(embedding 只需存 lookup 结果,0.5GB),但 backward 需要梯度。checkpoint 这里→反向时 embedding 被重算→梯度丢失。

# ❌ embedding 放在 checkpoint → 梯度丢失
class BadModel(nn.Module):
    def forward(self, input_ids):
        # embedding 不占大显存,但放 checkpoint 梯度没了
        x = checkpoint.checkpoint(self.embedding, input_ids)  # ← 错误!
        for layer in self.layers:
            x = layer(x)
        return self.lm_head(x)
# → loss 不降,embedding 的梯度 = 0

# ✅ 从第一个 Transformer 层开始 checkpoint
class CorrectModel(nn.Module):
    def forward(self, input_ids):
        x = self.embedding(input_ids)  # ← 正常 forward(不 checkpoint)
        for i, layer in enumerate(self.layers):
            if i % 2 == 0:
                x = layer(x)  # 检查点层
            else:
                x = checkpoint.checkpoint(layer, x)  # 重算层
        return self.lm_head(x)  # ← lm_head 也不 checkpoint

踩坑二:Gradient Checkpoint 和 FSDP 的 no_sync 冲突

FSDP 的 no_sync() 跳过 AllReduce——但 checkpoint 的反向传播走两次 forward(一次重算),梯度的 requires_grad 状态在重算 forward 时不正确 → FSDP 的 no_sync 判断失败 → 重复通信。

# ❌ no_sync + checkpoint → 重复 AllReduce(吞吐降 15%)
with model.no_sync():  # 期望跳过 AllReduce
    loss = model(mb)
    loss.backward()
# → checkpoint 内部重算时,FSDP 误判状态 → 做了 AllReduce
# → 8 个 micro-batch 做了 8 次 AllReduce(期望 1 次)

# ✅ 检查点层不放 no_sync 内,手动管理 FSDP 状态
for i, mb in enumerate(micro_batches):
    loss = model(mb) / accum_steps
    loss.backward()

    if i < accum_steps - 1:
        # 不等 AllReduce,攒着
        for param in model.parameters():
            if param.grad is not None:
                param.grad._accumulating = True  # 标记「还在累积」
    else:
        # 最后一步:做 AllReduce
        for param in model.parameters():
            if param.grad is not None:
                param.grad._accumulating = False

踩坑三:HBM Empty Cache 过多反而降低吞吐

每步 torch.npu.empty_cache() 释放碎片——但执行太频繁(如每个 micro-batch 调一次)→ NPU 驱动反复释放/分配 → 额外 2-3ms 开销 × 8 micro-batches = 16-24ms per step。

# ❌ 每个 micro-batch 都 empty_cache → 20ms 额外开销
for i, mb in enumerate(micro_batches):
    loss = model(mb)
    loss.backward()
    torch.npu.empty_cache()  # ← 每步 2ms → 8 步 = 16ms

# ✅ 只在显存压力大时(最后一步前)调一次
for i, mb in enumerate(micro_batches):
    loss = model(mb)
    loss.backward()

    # 只在 OOM 风险高时清理(显存使用 > 90%)
    if i == accum_steps - 2:  # 倒数第 2 步
        mem_used = torch.npu.memory_allocated() / torch.npu.max_memory_allocated()
        if mem_used > 0.9:
            torch.npu.empty_cache()

# 最后一步后必然清理(# 为 optimizer step 腾出空间)
torch.npu.empty_cache()

梯度累积 + Gradient Checkpoint 联合用:Checkpoint 把激活从 28GB 压到 14GB(每 2 层一个检查点),梯度累积把 batch 切成 4 个 micro-batch(每个 3.5GB 激活)→ HBM 总计 24.5GB。吞吐从 8,200 tokens/s 降到 6,800(17%),但 batch 从 128 扩到 512 等效,收敛加速 4×。通信省了 75%(4 micro-batch 只做 1 次 AllReduce)。三个踩坑:embedding/lm_head 不 checkpoint 否则梯度丢失、checkpoint 反向重算触发 FSDP 误判导致重复 AllReduce、empty_cache 太频繁反噬吞吐(每步 2ms 开销)。

Logo

作为“人工智能6S店”的官方数字引擎,为AI开发者与企业提供一个覆盖软硬件全栈、一站式门户。

更多推荐