Zero Redundancy OptimizerZeRO 介紹
什麼是 ZeRO?
ZeRO (Zero Redundancy Optimizer) 是 DeepSpeed 的核心技術,目的是減少大型模型在多 GPU 分布式訓練時的記憶體消耗,它將 優化器狀態、梯度、參數 分散到多張 GPU 上。ZeRO 分三個 Stage,每個階段都把「要分散的部分」再往前推進一步。
ZeRO Stage 1: 分散 Optimizer States
- 只把 Optimizer 的狀態(如 Adam 的 m、v)分散到不同 GPU。
- 模型參數和梯度仍然每張 GPU 都有一份完整副本。
- 減少的顯存主要是 Optimizer state,用於模型不算特別大時效果已明顯。
- 缺點:梯度和參數還是全量複製,對非常大模型效果有限。
ZeRO Stage 2: 分散 Optimizer States + Gradients
- 除了 Optimizer state,梯度也分散到各 GPU。
- 減少顯存需求再進一步,因梯度對大模型顯存佔用很可觀。
- 適合百億級參數模型,訓練穩定性與效能仍然很好。
ZeRO Stage 3: 分散 Optimizer States + Gradients + Parameters
- 連模型參數本身都分散到各 GPU。
- 各 GPU 僅持有自己負責的參數分片,訓練時透過通信完成計算。
- 記憶體節省最大,可把單 GPU 記憶體需求降低到幾十分之一。
- 適合超大型模型(上百億甚至千億參數級別)。
- 缺點:由於參數也分散,前向/反向過程需要大量跨 GPU 通訊 → 會影響吞吐量與速度,尤其在跨節點(多機)訓練時。
- 強烈建議配合 gradient checkpointing、混合精度(fp16/bf16)一起用,才能發揮最大效益。
- 如果遇到速度太慢,通常是網路通訊瓶頸,可以嘗試 ZeRO-Offload(把部分狀態放 CPU)、或者優化 NCCL 拓撲。
Stage 比較表
| ZeRO Stage | 分散 Optimizer | 分散梯度 | 分散參數 | 記憶體節省效果 | 適用情況 |
|---|---|---|---|---|---|
| Stage 1 | ✅ | ❌ | ❌ | ⭐ | 10B 參數內模型 |
| Stage 2 | ✅ | ✅ | ❌ | ⭐⭐ | 10B–50B 模型 |
| Stage 3 | ✅ | ✅ | ✅ | ⭐⭐⭐ | 50B 以上超大模型 |
ds_config.json 範例
train_batch_size:全域批次大小(= micro_batch_size_per_gpu × GPU 數 × gradient accumulation steps)train_micro_batch_size_per_gpu:每張 GPU 單次前向的 mini-batch 大小fp16.enabled:啟用混合精度zero_optimization.stage:ZeRO 階段設定,1/2/3 都可offload_optimizer.device:可用"cpu"減少 GPU 顯存壓力(ZeRO-Offload)overlap_comm:開啟通訊/計算重疊,通常能提高速度
{
"train_batch_size": 32,
"train_micro_batch_size_per_gpu": 4,
"steps_per_print": 100,
"optimizer": {
"type": "AdamW",
"params": {
"lr": 2e-5,
"betas": [0.9, 0.999],
"eps": 1e-8,
"weight_decay": 0.01
}
},
"fp16": {
"enabled": true,
"loss_scale_window": 100
},
"zero_optimization": {
"stage": 2, // ⚠️ 這裡改成 1、2、3 切換 ZeRO 階段
"offload_optimizer": {
"device": "none" // 改成 "cpu" 可啟用 ZeRO-Offload
},
"offload_param": {
"device": "none"
},
"contiguous_gradients": true,
"overlap_comm": true,
"reduce_bucket_size": 5e8,
"stage3_max_live_parameters": 1e9,
"stage3_max_reuse_distance": 1e9,
"stage3_param_persistence_threshold": 1e6
},
"gradient_clipping": 1.0,
"wall_clock_breakdown": false
}
接下來在 Python 訓練腳本中寫入
model_engine, optimizer, _, _ = deepspeed.initialize(
model=model,
model_parameters=model.parameters(),
config='ds_config.json'
)
