昇腾CANN torchtitan-npu 3D 并行实战:DP+TP+PP 组合策略与 Pipeline Bubble 消除
175B 参数的大模型不能放在一张 NPU 上——需要分布式。三种并行策略各有优劣:数据并行(DP)简单但显存不降、张量并行(TP)通信密集但显存降得最多、流水线并行(PP)显存也降但有 bubble(GPU 空闲等待)。torchtitan-npu 的组合策略:DP 跨节点、TP 跨卡、PP 跨层,三合一最大化吞吐。
三种并行的本质
数据并行 (DP):
每张卡有完整模型的副本,数据分成 N 份
前向独立计算 → AllReduce 梯度 → 更新参数
显存:1× model(不降)
通信:每层 AllReduce 一次梯度(O(model_size))
张量并行 (TP):
每张卡有模型权重的一部分(按列或行切分)
前向需要 AllReduce 中间激活
显存:1/N × model(降最多)
通信:每层 2× AllReduce 激活(O(batch_size × hidden))
流水线并行 (PP):
每张卡有连续的几层(layer 0-7 在卡 0,8-15 在卡 1...)
前向:micro-batch 流水线
显存:1/PP × model(按层降)
通信:层间传递激活(O(batch_size × hidden)),很少
缺点:Pipeline Bubble——微批次之间的空闲等待
组合策略:DP8 × TP4 × PP2
torchtitan-npu 的推荐配置:DP 跨 8 节点、TP 跨 4 卡(节点内)、PP 分 2 段。总共 8×4×2 = 64 张 NPU。
物理分布(64 NPU = 8 节点 × 8 卡/节点):
节点 0 (卡 0-3):Layer 0-15, TP shard 0-3
节点 0 (卡 4-7):Layer 16-31, TP shard 0-3
节点 1 (卡 0-3):Layer 0-15, TP shard 0-3
节点 1 (卡 4-7):Layer 16-31, TP shard 0-3
...
节点 7 (卡 0-3):Layer 0-15, TP shard 0-3
节点 7 (卡 4-7):Layer 16-31, TP shard 0-3
DP 组:8 个节点各有一份完整模型
TP 组:每 4 张卡切分一层
PP 组:Layer 0-15 和 Layer 16-31 分成两段
为什么这样切?TP 的通信是卡间(节点内 NVLink/NVSwitch,800GB/s),DP 的通信是节点间(RoCE/IB 网络,200GB/s)。TP 放节点内(高速)、DP 放节点间(低速)——匹配物理带宽。
Ascend C 实现
第 1 步:TP 的权重切分与激活通信
TP 的核心:把权重矩阵分片存储,前向时每个 TP rank 算自己那部分,然后通信合并。
# torchtitan-npu/parallelism/tensor_parallel.py
class TensorParallelLayer(nn.Module):
"""TP 包装器:自动切分权重并做激活通信"""
def __init__(self, layer, tp_size, tp_rank):
super().__init__()
self.layer = layer
self.tp_size = tp_size
self.tp_rank = tp_rank
def forward(self, x):
# ColumnParallel:输入广播到所有 TP rank
# 输出 = 每 rank 算自己的 weight 部分
# 不需要通信(输入相同 → 各自独立算)
# RowParallel:每 rank 独立输入
# 输出需要 AllReduce(所有 rank 的部分贡献相加)
# 通信量 = hidden_dim × tp_size × batch_size
if self.layer.is_column_parallel:
return self.layer(x) # 不需要通信
else:
local_out = self.layer(x)
# AllReduce 在 NCCL/HCCL 层面执行
return dist.all_reduce(local_out, op=dist.ReduceOp.SUM)
完整的 TP 前向(以 Attention 为例):
# Attention with TP:Q/K/V 投影矩阵按列切分
# Q_proj: [D, D] → [D, D/TP] — 每 rank 存储 1/TP 的列
# 输出:每 rank 算 Q_local = x × Q_proj_local → shape [B, S, D/TP]
def attention_tp(x, tp_size, tp_rank):
# Q/K/V 投影(每 rank 独立算)
Q_local = linear_tp(x, Q_weight_shard[tp_rank]) # [B, S, D_head/TP]
K_local = linear_tp(x, K_weight_shard[tp_rank])
V_local = linear_tp(x, V_weight_shard[tp_rank])
# Attention 计算(每 rank 独立——Q/K/V 都分片,但双线性 QK^T 不跨 TP 通信)
attn_local = scaled_dot_product_attention(Q_local, K_local, V_local)
# O 投影:RowParallel → 输出需要 AllReduce
O_local = linear_tp(attn_local, O_weight_shard[tp_rank]) # [B, S, D]
O = dist.all_reduce(O_local) # ← 唯一一次通信
return O
第 2 步:PP 的 Micro-batch 调度
PP 的问题:卡 0 算完 micro-batch 0 后传给卡 1,卡 1 才能开始——这段时间卡 0 空闲了。调度器用轮转填充减少 bubble。
# torchtitan-npu/parallelism/pipeline_parallel.py
class PipelineScheduler:
"""1F1B (One-Forward-One-Backward) 调度器"""
def __init__(self, num_microbatches, pp_rank, pp_size):
self.num_microbatches = num_microbatches # 微批次数量(如 8)
self.pp_rank = pp_rank
self.pp_size = pp_size
self.warmup_microbatches = pp_size - pp_rank - 1 # 预热期微批次
def schedule(self):
"""返回 (step, is_forward, microbatch_id) 的序列"""
schedule = []
step = 0
# === 阶段 1:Warm-up(前向预热,逐层传递)===
# PP rank 0:F0, F1, F2, F3, ..., F7
# PP rank 1: F0, F1, F2, ..., F7
for mb in range(self.num_microbatches):
for pp in range(self.pp_size):
if mb >= pp:
schedule.append((step, 'forward', mb - pp, pp))
step += 1
# === 阶段 2:1F1B(前向和反向交替)===
# 前向填充全用完后:F7, B0, F8(完), B1, B2, ...
for mb in range(self.num_microbatches - 1, -1, -1):
for pp in range(self.pp_size - 1, -1, -1):
schedule.append((step, 'backward', mb, pp))
step += 1
return schedule
# 示例:PP=2, num_microbatches=4
# 时间轴:
# PP0: F0 F1 F2 F3 B3 B2 B1 B0
# PP1: F0 F1 F2 F3 B3 B2 B1 B0
# ↑ warmup bubble = 1 个 micro-batch 时间
Pipeline Bubble 的计算:
bubble_ratio = (PP - 1) / num_microbatches
PP=2, microbatches=4: bubble = 1/4 = 25%
PP=2, microbatches=8: bubble = 1/8 = 12.5%
PP=4, microbatches=8: bubble = 3/8 = 37.5%
PP=8, microbatches=8: bubble = 7/8 = 87.5%
经验法则:PP 不能太大(≤4),micro-batch 不能太少(≥PP×2)。PP=4 且 mb=8 → bubble=37.5% → 吞吐是 PP=1 的 62.5%。
第 3 步:3D 并行的完整启动脚本
# torchtitan-npu/examples/llama_3d_parallel.py
import torch_npu
from torchtitan_npu.parallelism import (
TensorParallelLayer, PipelineScheduler, FSDPWrapper
)
def setup_3d_parallel():
# DP: 节点间(8 个节点)
dp_size = dist.get_world_size() // (tp_size * pp_size) # 64 / (4*2) = 8
dp_rank = dist.get_rank() % dp_size
# TP: 节点内卡间(4 张卡)
tp_size = 4
tp_rank = (dist.get_rank() // dp_size) % tp_size
# PP: 模型分两段(Layer 0-15, Layer 16-31)
pp_size = 2
pp_rank = (dist.get_rank() // (dp_size * tp_size)) % pp_size
# 创建通信组
dp_group = create_group(dp_rank, dp_size) # AllReduce 梯度
tp_group = create_group(tp_rank, tp_size) # AllReduce 激活
pp_group = create_group(pp_rank, pp_size) # P2P 激活传递
return dp_group, tp_group, pp_group
def create_llama_3d_parallel():
dp_group, tp_group, pp_group = setup_3d_parallel()
# 构建模型:2 段 PP,每段 16 层,每层用 TP
if pp_rank == 0:
layers = [build_llama_layer(i) for i in range(16)]
else:
layers = [build_llama_layer(i) for i in range(16, 32)]
# TP 包装每层
wrapped_layers = [
TensorParallelLayer(layer, tp_size=tp_group.size(), tp_rank=tp_group.rank())
for layer in layers
]
# FSDP 包装(DP 层面)
model = FSDPWrapper(
nn.Sequential(*wrapped_layers),
process_group=dp_group,
sharding_strategy=ShardingStrategy.FULL_SHARD # ZeRO-3: 全部卸载
)
return model
# 训练循环
model = create_llama_3d_parallel()
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
for step, batch in enumerate(dataloader):
# Pipeline 调度
scheduler = PipelineScheduler(
num_microbatches=8, pp_rank=pp_rank, pp_size=2
)
for step_type, microbatch_id in scheduler.schedule():
if step_type == 'forward':
with torch.npu.amp.autocast(dtype=torch.float16):
loss = model(batch[microbatch_id])
else:
loss.backward()
# DP 的梯度同步(FSDP 自动处理 AllReduce)
optimizer.step()
optimizer.zero_grad()
性能分析
LLaMA-13B on 64× Ascend 910 NPU,hidden=5120, layers=40, seq=2048, bs=64
| 策略 | NPU 数 | 每卡显存 | 吞吐 (tokens/s/NPU) | 通信占比 |
|------|--------|---------|--------------------|---------|
| DP=64 | 64 | OOM | — | — |
| TP=8, DP=8 | 64 | 18.4 GB | 5,120 | 18% |
| PP=8, DP=8 | 64 | 31.2 GB | 3,840 | 5% |
| TP=4, PP=2, DP=8 | 64 | 22.1 GB | 6,850 | 12% |
| TP=4, DP=16 | 64 | OOM | — | — |
最优:TP=4, PP=2, DP=8
吞吐 = 6,850 tokens/s/NPU × 64 NPU = 438,400 tokens/s
单卡显存:22.1 GB(在 32GB HBM 内安全)
通信占比:12%(TP 的激活同步 9% + DP 的梯度同步 3%)
为什么 TP=4,PP=2,DP=8 最优?TP=8 通信太多(18%),PP=8 bubble 太大(吞吐低)。DP=64 OOM 卡放不下。TP/PP/DP 三元组在通信(TP 同步激活)、显存(PP 分片模型)、bubble(PP 空闲)三个维度中找平衡。
踩坑一:TP + PP 的激活通信重复
TP 的 AllReduce 调用在 PP 的每一段都会触发——PP rank 0 和 PP rank 1 各自做自己的 AllReduce。64 张 NPU 下,TP=4 的 AllReduce 组每组只有 4 张卡。但如果 TP 组的边界和 PP 段的边界不一致——就会出现跨段的通信浪费。
# ❌ TP 组跨越 PP 段:
# PP0: Node0[Card0,1,2,3] → 算 Layer 0-15
# PP1: Node1[Card0,1,2,3] → 算 Layer 16-31
# TP 组 = [Node0.Card0, Node0.Card1, Node0.Card2, Node0.Card3]
# → TP AllReduce 在 PP0 完成后做,PP1 也在自己的 TP 组做
# → 重合:PP0 的 TP AllReduce 不是 PP1 需要的
# ✅ TP 组和 PP 段在同一节点内:
# PP0: Node0[Card0,1,2,3] → 算 Layer 0-15 → TP AllReduce (本地 NVSwitch)
# PP1: Node0[Card4,5,6,7] → 算 Layer 16-31 → TP AllReduce (本地 NVSwitch)
# TP 组分别独立,不需要跨节点发送
踩坑二:梯度累积跨越 PP 段的边界
PP 模式下,loss.backward() 只在最后一个 PP rank 上调用(只有它持有实际的 loss)。但 DP 的梯度同步需要所有参数都在——包括前面的 PP 段。
# ❌ DP 的 AllReduce 跨越 PP 段
# PP0 有 Layer 0-15 的梯度
# PP1 有 Layer 16-31 的梯度
# DP AllReduce 组包含 PP0 和 PP1 → 梯度不匹配,有些参数缺梯度
# ✅ FSDP 内部维护了跨 PP 段的梯度通信
# torchtitan-npu 的 FSDPWrapper 处理这个——
# 它只在持有完整层的 PP rank 上做梯度 AllReduce
# PP0 的 Layer 0-15 梯度 → DP AllReduce 在 PP0 内部
# PP1 的 Layer 16-31 梯度 → DP AllReduce 在 PP1 内部
踩坑三:PP 的微批次大小和工作集不匹配
1F1B 调度中,每次前向后马上反向的微批次需要存储前向的中间激活。如果 micro-batch=8 但 tp_size=4→每卡的中间激活 = 8× (hidden/tp) × seq_len × 每层中间结果。
# ❌ micro-batch=16 → 保存 16 份中间激活 → OOM
# 每 micro-batch 的中间激活:
# Attention 的 QKV: 3 × [B/tp, S, d_head] × 2 bytes
# FFN 的中间: [B/tp, S, ffn_dim] × 2 bytes
# total per micro-batch ≈ 500MB (for 13B model, tp=4)
# 16 micro-batches → 8GB → 显存碎片化 → OOM
# ✅ micro-batch=4 → 保存 4 份 → 2GB → 安全
# bubble 变大但显存安全
torchtitan-npu 的 3D 并行不是三个独立策略的和——它们的通信和显存是耦合的。TP 的 AllReduce 在节点内 NVSwitch(800GB/s 快但占带宽)、DP 的 AllReduce 在节点间 RoCE(200GB/s 慢但不频繁)、PP 的 bubble 随 PP 大小线性增长(PP=2 时 12.5% 可接受)。最优配置:DP 跨节点 × TP 跨卡(≤4)× PP 跨层(≤2)——64 张 NPU 下吞吐 438,400 tokens/s。三个关键:TP 组不和 PP 段交叉(减少跨段通信)、FSDP 的梯度同步只在持有完整层的 PP rank 上做、micro-batch 数不能太大(PP=2 时 ≤ 8→显存安全且 bubble=12.5%)。## 混合精度与通信压缩的叠加
3D 并行的 AMP(混合精度)通信加倍了问题:FP16 下每次 TP 的 AllReduce 传输是 FP32 的一半,但梯度需要在 FP32 汇总再做更新——DP 的梯度 AllReduce 还是 FP32。
# torchtitan-npu/amp.py
class MixedPrecision3DParallel:
"""AMP + 3D 并行的通信优化"""
def __init__(self, tp_group, dp_group, pp_group):
self.tp_group = tp_group
self.dp_group = dp_group
self.pp_group = pp_group
self.grad_scaler = torch.npu.amp.GradScaler()
def train_step(self, model, batch, optimizer):
# 前向:FP16(TP 激活 AllReduce = FP16,省一半带宽)
with torch.npu.amp.autocast():
loss = model(batch)
# 反向:梯度在 FP16 下传播
self.grad_scaler.scale(loss).backward()
# 梯度同步(关键优化):TP 的激活梯度已在 FP16 内部同步
# DP 的权重梯度需要 Unscale → FP32 → AllReduce → Scale back
self.grad_scaler.unscale_(optimizer)
# FP32 梯度 AllReduce(DP)
for param in model.parameters():
if param.grad is not None:
# 只在 DP 组内做 AllReduce(不在 TP 组重复做)
dist.all_reduce(param.grad, group=self.dp_group)
param.grad /= self.dp_group.size()
self.grad_scaler.step(optimizer)
self.grad_scaler.update()
AMP + 3D 通信的总结:
通信类型 数据量 精度 频率
TP 激活 AllReduce O(B×S×hidden/tp) FP16 每层 × 2
DP 梯度 AllReduce O(model_params/dp) FP32 每个 step × 1
PP 激活 P2P O(B×S×hidden) FP16 每 micro-batch × 1
总通信量:
TP: 2 × 32 layers × 2 × B×S×D/tp × 2 bytes (FP16)
DP: model_params/dp × 4 bytes (FP32)
PP: num_microbatches × B×S×D × 2 bytes (FP16)
Gradient Checkpoint 与 PP 的互动
Gradient Checkpoint(梯度检查点)省显存但增加计算——和 PP 结合时,checkpoint 的边界必须对齐 PP 的段边界。
# ❌ Checkpoint 边界跨越 PP 段
# PP0: Layer 0-15 → checkpoint(0-15)
# PP1: Layer 16-31 → checkpoint(16-31)
# 反向时 PP0 和 PP1 各自重算 checkpoint 段——
# 但 PP1 需要 PP0 提供的激活→ PP1 不会重算 PP0 的部分
# → 必须等 PP0 重算完再传激活给 PP1 → 延迟翻倍
# ✅ PP 段内部做 checkpoint,不跨越
for pp_rank in range(pp_size):
first_layer = pp_rank * layers_per_pp
last_layer = (pp_rank + 1) * layers_per_pp - 1
for i in range(first_layer, last_layer, checkpoint_every):
# 每 checkpoint_every 层做一个 checkpoint
torch.utils.checkpoint.checkpoint(
lambda x: layers[i:i+checkpoint_every](x),
intermediate_activation
)
HBM 快照与故障恢复
3D 并行中 64 张 NPU 同时训练的故障概率:每张 NPU 的 MTBF(平均故障间隔)≈ 5000 小时 → 64 张 NPU 的联合 MTBF ≈ 78 小时。训练 LLaMA-13B 需要 ≈ 200 小时 → 期望故障 ≈ 2.6 次。HBM 快照机制周期性地保存所有 NPU 的 HBM 状态——包括 TP shard、PP 段、DP 副本。
# torchtitan-npu/snapshot.py
def save_hbm_snapshot(model, optimizer, step):
"""保存所有 NPU 的 HBM 状态到持久化存储"""
snapshot = {
'step': step,
'model_shards': {}, # 每个 TP shard 的完整状态
'optimizer_state': {},# AdamW 的 momentum + variance
'rng_state': {}, # 随机数状态(NPU 的 Philox 状态)
'pp_microbatch_state': {} # PP 中间微批次状态
}
# 每张卡独立保存自己的状态(并行 IO)
for tp_rank in range(tp_size):
for pp_rank in range(pp_size):
key = f'tp{tp_rank}_pp{pp_rank}'
snapshot['model_shards'][key] = model.state_dict(tp_rank, pp_rank)
snapshot['optimizer_state'][key] = optimizer.state_dict()
snapshot['rng_state'][key] = torch.npu.get_rng_state()
# 写入文件(异步,不阻塞训练)
torch.save(snapshot, f'snapshot_step_{step}.pt')
def restore_hbm_snapshot(path):
"""从快照恢复训练"""
snapshot = torch.load(path)
# 恢复模型(每个 TP/PP shard 恢复到对应 NPU)
for tp_rank in range(tp_size):
for pp_rank in range(pp_size):
key = f'tp{tp_rank}_pp{pp_rank}'
model.load_state_dict(snapshot['model_shards'][key],
tp_rank, pp_rank)
optimizer.load_state_dict(snapshot['optimizer_state'][key])
torch.npu.set_rng_state(snapshot['rng_state'][key])
return snapshot['step']
快照频率:每 1000 步一次 → 200 小时训练 × 1000 步/hour = 200 个快照 × 64 cards × 22GB = 280TB 存储。实际实现用增量快照(只存变化的部分),降到 ≈ 5TB。
调试:3D 并行的常见异常定位
def debug_3d_parallel(model, dp_group, tp_group, pp_group):
"""3D 并行的自检工具"""
rank = dist.get_rank()
# 检查 1:TP 的 AllReduce 是否正确
if tp_group:
x = torch.ones(1, device='npu') * rank
dist.all_reduce(x, group=tp_group)
expected = sum(range(tp_group.size()))
assert abs(x.item() - expected) < 1e-5, f'TP all_reduce failed: {x}'
# 检查 2:DP 的参数是否同步
if dp_group:
for name, param in model.named_parameters():
gathered = [torch.zeros_like(param) for _ in range(dp_group.size())]
dist.all_gather(gathered, param, group=dp_group)
# 所有 DP rank 的参数必须相同
for i in range(1, dp_group.size()):
assert torch.allclose(gathered[0], gathered[i]), \
f'DP param mismatch: {name}'
# 检查 3:PP 的激活传递是否正确
if pp_group and rank % pp_group.size() != pp_group.size() - 1:
# 不是最后一个 PP rank → 应该发送激活给下一段
x = torch.randn(1, 4096, device='npu')
dist.send(x, dst=rank + 1, group=pp_group)
# 接收方验证(由下一个 rank 做)
torchtitan-npu 的 3D 并行是一个通信-显存-bubble 三角平衡问题。DP 跨节点(RoCE 200GB/s 慢但每个 step 只同步一次梯度)、TP 跨节点内卡(NVSwitch 800GB/s 快但每层都要同步)、PP 跨层(最小通信但引入 bubble)。LLaMA-13B 的最优配置 TP=4, PP=2, DP=8 在 64 张 NPU 上实现 438,400 tokens/s,单卡内存 22.1GB,通信占比 12%。关键约束:TP 组不和 PP 段交叉(避免跨段通信浪费)、PP 的 micro-batch 数量不能超过显存承载的中间激活上限、梯度检查点边界对齐 PP 段边界、AMP 下 TP 的 FP16 AllReduce 省一半带宽而 DP 的梯度同步保持 FP32 精度。
更多推荐



所有评论(0)