PaperReading:DeepSpeed
paperreading
本文字数:2.1k 字 | 阅读时长 ≈ 8 min

PaperReading:DeepSpeed

paperreading
本文字数:2.1k 字 | 阅读时长 ≈ 8 min

1. Deepspeed

DeepSpeed 是微软开源的分布式训练框架,主要用于训练大模型,主要分为三个阶段:

1.1 Zero Stage 0(全量复制)

其实 ZeRO-0 并不是 DeepSpeed 提供的一种优化,而是指传统数据并行(DP)的实现,即所有 GPU 节点上完整存储模型参数、优化器状态和梯度,仅用于对比后续 ZeRO 阶段。

🔎 优缺点:
✅ 实现简单,适合中小模型。
❌ 对显存要求极高,浪费内存资源。
❌ 随着模型参数量增加,内存成为瓶颈(单卡无法容纳模型)。

1.2 Zero Stage 1(优化器状态分片)

将优化器状态分片,只在局部 GPU 存储对应片段,但仍然保留完整的梯度和参数。

🔎 优缺点:
✅ 易于实现,适合中等规模模型。
✅ 内存占用显著降低(约节省 4x~8x 显存,取决于优化器)。
❌ 梯度和参数仍完整复制,显存消耗仍较大。
❌ 不适合超大模型(如 GPT-3),无法彻底解决内存瓶颈。

1.3 Zero Stage 2(再分片梯度)

在 ZeRO-1 的基础上,将梯度分片,显著减少反向传播中的显存消耗。

🔎 优缺点:
✅ 显存占用进一步降低(比 ZeRO-1 再降低约 2x)。
✅ 支持更大模型训练(在内存有限场景下)。
✅ 通信量较低(只在梯度同步时通信)。
❌ 参数仍然完整存储在每张卡上。
❌ 如果模型参数过大,单卡内存仍可能不够。

1.4 Zero Stage 3(再分片参数,完全分布式)

在 ZeRO-2 的基础上,将模型参数分片,使得每张 GPU 仅存储自己负责的参数、梯度和优化器状态。真正实现全局内存分布式。

🔎 优缺点:
✅ 显存占用极低(理论上支持任意大模型训练)。
✅ 可与混合精度、激活重计算(checkpoint)结合进一步优化。
❌ 通信开销显著增加(尤其是 AllGather / ReduceScatter)。
❌ 实现复杂度高,对分布式网络要求较高。

1.5 总结对比表

阶段 优化器状态 梯度 参数 显存节省量 通信开销
Zero-0 全量复制 全量 全量
Zero-1 分片 全量 全量 ~4x
Zero-2 分片 分片 全量 ~8x
Zero-3 分片 分片 分片 最高

2. DeepSpeed ZeRO 各阶段实现详解

2.1 ZeRO-0:全量复制(Data Parallel)

model = MyLargeModel().to(device)
model = DDP(model, device_ids=[rank])

for batch in data_loader:
    optimizer.zero_grad()
    loss = model(batch)
    loss.backward()
    optimizer.step()

通信:每次反向传播后,DDP 会自动用 AllReduce 同步梯度。

2.2 ZeRO-1:优化器状态分片

from deepspeed import initialize

model_engine, optimizer, _, _ = initialize(
    model=model,
    model_parameters=model.parameters(),
    config='ds_config_zero1.json'
)

'''
ds_config_zero1.json 示例:
{
  "train_batch_size": 32,
  "zero_optimization": {
    "stage": 1
  }
}
'''

通信:减少了优化器状态传输,但参数和梯度仍然完整存储和同步。

2.3 ZeRO-2:优化器状态+梯度分片

model_engine, optimizer, _, _ = initialize(
    model=model,
    model_parameters=model.parameters(),
    config='ds_config_zero2.json'
)

核心逻辑:

2.4 ZeRO-3:优化器状态+梯度+参数分片(全分布式)

核心代码:

model_engine, optimizer, _, _ = initialize(
    model=model,
    model_parameters=model.parameters(),
    config='ds_config_zero3.json'
)

示例配置 ds_config_zero3.json:

{
  "train_batch_size": 32,
  "zero_optimization": {
    "stage": 3,
    "offload_param": {
      "device": "cpu"
    },
    "offload_optimizer": {
      "device": "cpu"
    }
  }
}

2.5 关键通信:AllGather & ReduceScatter

模拟 AllGather

param_part = get_local_partition() # 本地分片
full_param = torch.cat(dist.all_gather(param_part), dim=0)

模拟 ReduceScatter

local_grad = compute_local_grad()
global_grad = dist.reduce_scatter(torch.zeros_like(local_grad), local_grad)

2.6 总结关键流程:

阶段 优化器状态 梯度 参数 前向关键通信 反向关键通信
ZeRO-0 全量 全量 全量 AllReduce
ZeRO-1 分片 全量 全量 AllReduce
ZeRO-2 分片 分片 全量 ReduceScatter
ZeRO-3 分片 分片 分片 AllGather ReduceScatter

3. 不同 stage 分片是如何考量的

🎯 为什么 ZeRO 分阶段采用这种顺序?(优化器 ➡ 梯度 ➡ 参数)

🔹 1️⃣ ZeRO-1:先分片优化器状态(Optimizer State Partitioning)

📌 背后考虑:

📊 优先切优化器状态 = 最大内存收益(通常节省 4-8x 内存)+ 最小工程复杂度

🔹 2️⃣ ZeRO-2:再分片梯度(Gradient Partitioning)

📌 背后考虑:

📊 在保留较低工程复杂度的基础上,进一步显著降低内存(再减约 2x)

🔹 3️⃣ ZeRO-3:最后切分参数(Parameter Partitioning)

📌 背后考虑:

📊 尽管参数分片节省显存最多(约 N 倍),但工程复杂度和通信开销也最高,因此最后实现

🏆 整体思路总结:

阶段 优化目标 技术原因 工程考虑
ZeRO-1 优化器状态分片 内存开销大但通信量低,使用简单 易实现,收益明显
ZeRO-2 梯度分片 较大内存开销,需要参与通信(ReduceScatter) 工程复杂度适中
ZeRO-3 参数分片(完全分布) 内存开销最大,通信需求高,前向/反向都需插入通信 工程复杂度最高,最后实现

💡 总结

✅ ZeRO 分阶段是基于 内存收益最大化 + 工程复杂度最小化 + 通信优化平衡
✅ 先从 易实现、收益高 的优化器状态下手(ZeRO-1),再处理涉及反向传播的梯度(ZeRO-2),最后处理需要频繁 AllGather 的参数(ZeRO-3)
✅ 这种循序渐进的方法确保了从传统数据并行平滑过渡到完全分布式(即 ZeRO-3)

5月 06, 2025
4月 06, 2025
ufw