Zero Redundancy OptimizerZeRO 介紹
什麼是 ZeRO?
ZeRO (Zero Redundancy Optimizer) 是 DeepSpeed 的核心技術,目的是減少大型模型在多 GPU 分布式訓練時的記憶體消耗,它將 優化器狀態、梯度、參數 分散到多張 GPU 上。ZeRO 分三個 Stage,每個階段都把「要分散的部分」再往前推進一步。
ZeRO Stage 1: 分散 Optimizer States
- 只把 Optimizer 的狀態(如 Adam 的 m、v)分散到不同 GPU。
- 模型參數和梯度仍然每張 GPU 都有一份完整副本。
- 減少的顯存主要是 Optimizer state,用於模型不算特別大時效果已明顯。
- 缺點:梯度和參數還是全量複製,對非常大模型效果有限。
ZeRO Stage 2: 分散 Optimizer States + Gradients
- 除了 Optimizer state,梯度也分散到各 GPU。
- 減少顯存需求再進一步,因梯度對大模型顯存佔用很可觀。
- 適合百億級參數模型,訓練穩定性與效能仍然很好。
ZeRO Stage 3: 分散 Optimizer States + Gradients + Parameters
- 連模型參數本身都分散到各 GPU。
- 各 GPU 僅持有自己負責的參數分片,訓練時透過通信完成計算。
- 記憶體節省最大,可把單 GPU 記憶體需求降低到幾十分之一。
- 適合超大型模型(上百億甚至千億參數級別)。
- 缺點:由於參數也分散,前向/反向過程需要大量跨 GPU 通訊 → 會影響吞吐量與速度,尤其在跨節點(多機)訓練時。
- 強烈建議配合 gradient checkpointing、混合精度(fp16/bf16)一起用,才能發揮最大效益。
- 如果遇到速度太慢,通常是網路通訊瓶頸,可以嘗試 ZeRO-Offload(把部分狀態放 CPU)、或者優化 NCCL 拓撲。
Stage 比較表
ZeRO Stage | 分散 Optimizer | 分散梯度 | 分散參數 | 記憶體節省效果 | 適用情況 |
---|---|---|---|---|---|
Stage 1 | ✅ | ❌ | ❌ | ⭐ | 10B 參數內模型 |
Stage 2 | ✅ | ✅ | ❌ | ⭐⭐ | 10B–50B 模型 |
Stage 3 | ✅ | ✅ | ✅ | ⭐⭐⭐ | 50B 以上超大模型 |
ds_config.json
範例
train_batch_size
:全域批次大小(= micro_batch_size_per_gpu × GPU 數 × gradient accumulation steps)train_micro_batch_size_per_gpu
:每張 GPU 單次前向的 mini-batch 大小fp16.enabled
:啟用混合精度zero_optimization.stage
:ZeRO 階段設定,1/2/3 都可offload_optimizer.device
:可用"cpu"
減少 GPU 顯存壓力(ZeRO-Offload)overlap_comm
:開啟通訊/計算重疊,通常能提高速度
{ "train_batch_size": 32, "train_micro_batch_size_per_gpu": 4, "steps_per_print": 100, "optimizer": { "type": "AdamW", "params": { "lr": 2e-5, "betas": [0.9, 0.999], "eps": 1e-8, "weight_decay": 0.01 } }, "fp16": { "enabled": true, "loss_scale_window": 100 }, "zero_optimization": { "stage": 2, // ⚠️ 這裡改成 1、2、3 切換 ZeRO 階段 "offload_optimizer": { "device": "none" // 改成 "cpu" 可啟用 ZeRO-Offload }, "offload_param": { "device": "none" }, "contiguous_gradients": true, "overlap_comm": true, "reduce_bucket_size": 5e8, "stage3_max_live_parameters": 1e9, "stage3_max_reuse_distance": 1e9, "stage3_param_persistence_threshold": 1e6 }, "gradient_clipping": 1.0, "wall_clock_breakdown": false }
接下來在 Python 訓練腳本中寫入
model_engine, optimizer, _, _ = deepspeed.initialize( model=model, model_parameters=model.parameters(), config='ds_config.json' )