Lab 05 — Distributed Training
Phase 3: PyTorch | Week 10
When a model doesn't fit on one GPU, or training takes too long, you need distributed training. This is a required skill for any ML engineer working at scale.
Learning Objectives
- Understand DDP (DistributedDataParallel) vs DataParallel vs FSDP
- Implement gradient accumulation that provably matches large-batch training
- Quantify communication overhead: bandwidth, latency, model size tradeoffs
- Understand Amdahl's Law applied to distributed ML
- Write a production-ready DDP launch template
Theory
Data Parallelism — DDP
Each GPU holds a full model copy, and each global batch is split across the GPUs. Gradients are synchronized after each backward pass via All-Reduce, so every replica applies the same averaged update.
```text
GPU0: batch_0 → forward → backward → grad_0 ─┐
GPU1: batch_1 → forward → backward → grad_1 ─┤─→ AllReduce → averaged grad → update
GPU2: batch_2 → forward → backward → grad_2 ─┘
```
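Ring-AllReduce (NCCL): the GPUs form a ring, and each GPU sends to one neighbor and receives from the other. The gradient is split into $N$ chunks and reduced in two phases (reduce-scatter, then all-gather) of $N-1$ steps each, so each GPU sends $2 \cdot \frac{N-1}{N} \cdot M$ data, where $M$ is the model (gradient) size. This approaches $2M$ as $N$ grows, so per-GPU traffic is nearly independent of $N$ and the algorithm is bandwidth-optimal. The number of sequential steps, and therefore the latency cost, still grows linearly with $N$.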
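The two phases can be simulated in plain Python to check the traffic formula. This is a minimal sketch, not NCCL's implementation. The function name ring_allreduce is hypothetical, and it makes two simplifying assumptions: the "ranks" are list entries and the gradient length divides evenly by $N$. It counts the elements each rank sends and compares the total with $2(N-1)/N \cdot M$.

```python
def ring_allreduce(data):
    """Simulate a ring all-reduce (sum) across N ranks.

    data[r] is rank r's local gradient; its length must be divisible by N.
    Returns (result on every rank, number of elements each rank sent).
    DDP would additionally divide the sum by N to average the gradients.
    """
    n, m = len(data), len(data[0])
    assert m % n == 0, "gradient length must be divisible by the number of ranks"
    size = m // n
    buf = [list(v) for v in data]  # each rank's working copy
    sent = [0] * n

    def chunk(c):
        return slice(c * size, (c + 1) * size)

    # Phase 1, reduce-scatter: in step s, rank r sends chunk (r - s) % n to
    # rank r + 1, which adds it to its own copy. After n - 1 steps, rank r
    # holds the complete sum of chunk (r + 1) % n.
    for step in range(n - 1):
        outgoing = [((r - step) % n, buf[r][chunk((r - step) % n)]) for r in range(n)]
        for r, (c, payload) in enumerate(outgoing):
            dst = (r + 1) % n
            buf[dst][chunk(c)] = [a + b for a, b in zip(buf[dst][chunk(c)], payload)]
            sent[r] += size

    # Phase 2, all-gather: in step s, rank r forwards the finished chunk
    # (r + 1 - s) % n to rank r + 1, which overwrites its stale copy.
    for step in range(n - 1):
        outgoing = [((r + 1 - step) % n, buf[r][chunk((r + 1 - step) % n)]) for r in range(n)]
        for r, (c, payload) in enumerate(outgoing):
            buf[(r + 1) % n][chunk(c)] = payload
            sent[r] += size

    return buf, sent


# 4 ranks, 8-element gradients: every entry on rank r equals r + 1.
results, sent = ring_allreduce([[float(r + 1)] * 8 for r in range(4)])
print(results[0])  # every rank ends with the elementwise sum 1 + 2 + 3 + 4
print(sent, "vs formula", 2 * (4 - 1) / 4 * 8)
```

Each rank sends $2(N-1)$ chunks of $M/N$ elements. That count is the source of the $2(N-1)/N \cdot M$ term. Real NCCL may also choose tree algorithms and pipelines chunks, and this sketch models neither.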
Gradient Synchronization
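DDP provides the model.no_sync() context manager for gradient accumulation. Inside it, backward passes accumulate gradients locally without an all-reduce, and gradients are synchronized only on the last backward pass of each accumulation cycle: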
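```python
import contextlib

# Assumes: model is wrapped in DistributedDataParallel and its forward
# returns a scalar loss; optimizer, loader and accum_steps are defined.
for i, batch in enumerate(loader):
    if i % accum_steps == 0:
        optimizer.zero_grad()
    # All-reduce only on the last micro-batch of each accumulation cycle.
    sync = (i + 1) % accum_steps == 0
    context = contextlib.nullcontext() if sync else model.no_sync()
    with context:
        loss = model(batch) / accum_steps
        loss.backward()  # grads accumulate in .grad; synced only when sync is True
    if sync:
        optimizer.step()
```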
Gradient Accumulation (Single GPU)
Accumulating gradients over $S$ = accum_steps mini-batches of $N$ = batch_size samples each is mathematically equivalent to training with an effective batch of $N_{\text{eff}} = S \cdot N$ samples:
$$\frac{1}{N_{\text{eff}}} \sum_{i=1}^{N_{\text{eff}}} \nabla_\theta L_i = \frac{1}{S} \sum_{s=1}^{S} \left(\frac{1}{N} \sum_{i \in \text{mini-batch}_s} \nabla_\theta L_i\right)$$
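Proof: the $S$ mini-batches partition the effective batch and $N_{\text{eff}} = S \cdot N$, so both sides equal $\frac{1}{SN}\sum_i \nabla_\theta L_i$. In code, .backward() adds each mini-batch's gradient into .grad. Dividing each mini-batch (mean) loss by accum_steps before .backward() therefore leaves .grad holding exactly the right-hand side, assuming all mini-batches have the same size.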
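The identity can be checked numerically without PyTorch. This is a minimal sketch that assumes a one-parameter linear model $y \approx w x$ with mean-squared-error loss and hand-written gradients. The helper names and data are hypothetical. The running sum in accumulated_grad plays the role of the .grad buffer.

```python
def grad_mse(w, xs, ys):
    """Gradient w.r.t. w of the mean squared error mean((w*x - y)**2)."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def accumulated_grad(w, xs, ys, accum_steps):
    """Gradient accumulation: split into equal mini-batches and add up each
    mini-batch's mean gradient divided by accum_steps (the analogue of
    calling (loss / accum_steps).backward() on each mini-batch)."""
    assert len(xs) % accum_steps == 0, "mini-batches must be equal-sized"
    n = len(xs) // accum_steps
    grad = 0.0  # plays the role of the .grad buffer
    for s in range(accum_steps):
        batch = slice(s * n, (s + 1) * n)
        grad += grad_mse(w, xs[batch], ys[batch]) / accum_steps
    return grad


xs = [0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0]
ys = [1.1, 1.9, 3.2, 3.9, 5.1, 6.0, 7.2, 7.8]
w = 0.3

full_batch = grad_mse(w, xs, ys)                          # N_eff = 8
accumulated = accumulated_grad(w, xs, ys, accum_steps=4)  # S = 4, N = 2
print(full_batch, accumulated)  # equal up to floating-point rounding
```

Without the / accum_steps scaling, the accumulated gradient would be accum_steps times too large. That is equivalent to silently multiplying the learning rate.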
Scaling Efficiency — Amdahl's Law
If fraction $p$ of work is parallelizable:
$$\text{Speedup}(N) = \frac{1}{(1-p) + p/N}$$
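For DDP, the serial fraction $1-p$ is mostly gradient communication that is not overlapped with the backward pass. As rough, illustrative values: with a fast interconnect (NVLink, ~600 GB/s), $p \approx 0.99$; with a slow one (PCIe, ~50 GB/s), $p \approx 0.9$.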
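Plugging the illustrative $p$ values into the formula shows how quickly efficiency diverges. This is a minimal sketch. The values $p = 0.99$ and $p = 0.9$ come from the paragraph above and are not measurements.

```python
def amdahl_speedup(p, n):
    """Amdahl's law: speedup on n workers when a fraction p of the work parallelizes."""
    return 1.0 / ((1.0 - p) + p / n)


for label, p in [("fast interconnect", 0.99), ("slow interconnect", 0.90)]:
    for n in (2, 8, 64):
        s = amdahl_speedup(p, n)
        print(f"{label} (p={p}): N={n:3d} speedup={s:6.2f} efficiency={s / n:.0%}")
    print(f"  upper bound as N -> infinity: {1 / (1 - p):.0f}x")
```

With $p = 0.99$, eight GPUs give about a 7.5x speedup. With $p = 0.9$, they give only about 4.7x. No number of GPUs can exceed $1/(1-p)$, which is 100x and 10x respectively.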
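Linear scaling rule (Goyal et al., 2017): when the batch size grows from $B$ to $kB$ (e.g., $k$ GPUs with the same per-GPU batch), multiply the learning rate by $k$. For large $k$ this needs a warmup phase; Goyal et al. ramp the learning rate up gradually over the first 5 epochs.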
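A learning-rate function combining the scaling rule with gradual linear warmup might look like the sketch below. It assumes the warmup starts at the single-GPU rate and reaches $k$ times that rate after warmup_epochs. The function name and defaults are hypothetical, and any decay schedule after warmup is left out.

```python
def scaled_lr(base_lr, k, epoch, warmup_epochs=5):
    """Linear scaling rule with gradual warmup.

    base_lr: learning rate tuned for the reference batch size B.
    k: batch-size multiplier (new batch size kB).
    epoch: may be fractional, so the LR can be updated every iteration.
    """
    target_lr = base_lr * k  # batch kB -> learning rate k * base_lr
    if epoch >= warmup_epochs:
        return target_lr  # constant after warmup; this sketch omits later decay
    # Ramp linearly from base_lr to target_lr during warmup.
    return base_lr + (target_lr - base_lr) * epoch / warmup_epochs


for epoch in (0, 1, 2.5, 5, 10):
    print(epoch, round(scaled_lr(0.1, k=8, epoch=epoch), 4))
```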
FSDP — Fully Sharded Data Parallel
DDP keeps a full model copy on each GPU. FSDP shards model parameters, gradients, and optimizer state across GPUs:
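- Memory per GPU: $\approx (\text{params} + \text{grads} + \text{optimizer state}) / N$, plus activations and the temporarily gathered parameters of the layer being computed
- Each GPU persistently stores only its $1/N$ shard of the parameters
- Parameters are all-gathered just before they are needed for forward/backward and freed afterward; gradients are reduce-scattered so each GPU keeps only its own shard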
When to use: model > 10B params, or when model + optimizer state doesn't fit on a single GPU.
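A back-of-the-envelope calculation shows why optimizer state, not the weights alone, often forces sharding. This sketch assumes mixed-precision Adam at about 16 bytes of persistent state per parameter, a common accounting. It uses a hypothetical 7B-parameter model on 8 GPUs and ignores activations.

```python
# Assumption: mixed-precision Adam, ~16 bytes of persistent state per parameter.
BYTES_PER_PARAM = (
    2    # fp16/bf16 parameters
    + 2  # fp16/bf16 gradients
    + 4  # fp32 master copy of the parameters
    + 4  # fp32 Adam first moment (momentum)
    + 4  # fp32 Adam second moment (variance)
)


def per_gpu_state_gb(num_params, num_gpus, sharded):
    """Persistent params + grads + optimizer state per GPU, in GB (1e9 bytes).

    sharded=False models DDP (full replica per GPU); sharded=True models
    FSDP full sharding (each GPU keeps 1/num_gpus). Activations and the
    temporarily all-gathered parameters are not counted.
    """
    total_bytes = num_params * BYTES_PER_PARAM
    per_gpu = total_bytes / num_gpus if sharded else total_bytes
    return per_gpu / 1e9


params, gpus = 7e9, 8  # hypothetical 7B-parameter model on 8 GPUs
print(f"DDP : {per_gpu_state_gb(params, gpus, sharded=False):.0f} GB per GPU")
print(f"FSDP: {per_gpu_state_gb(params, gpus, sharded=True):.0f} GB per GPU")
```

Under these assumptions, a DDP replica needs about 112 GB per GPU, which is more than a single 80 GB accelerator holds. Full sharding over 8 GPUs brings the persistent state down to about 14 GB.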
3D Parallelism (Megatron-LM)
| Dimension | What's split | Typical use |
|---|---|---|
| Data parallel | Mini-batch | Fast, always use |
| Tensor parallel | Individual weight matrices | Large FC/attention layers |
| Pipeline parallel | Model layers across GPUs | Huge models (GPT-3+) |
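Tensor parallelism works because a matrix product can be split along the weight's columns. The sketch below uses plain Python lists as stand-ins for devices. Each "device" stores one column shard of $W$ and computes its slice of $XW$ independently, and concatenating the slices recovers the full output. All names are hypothetical. Megatron-LM pairs this column-parallel split with a row-parallel split of the following layer, and only the column-parallel half is shown here.

```python
def matmul(a, b):
    """Plain-Python matrix product of an (m x k) and a (k x n) matrix."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def split_columns(w, parts):
    """Split weight matrix w column-wise into `parts` equal shards."""
    size = len(w[0]) // parts
    return [[row[p * size:(p + 1) * size] for row in w] for p in range(parts)]


def column_parallel_linear(x, w, num_devices):
    """Column-parallel linear layer: device d stores only shard d of w and
    computes x @ shard_d independently; concatenating the partial outputs
    (an all-gather in a real system) recovers x @ w."""
    assert len(w[0]) % num_devices == 0
    partial = [matmul(x, shard) for shard in split_columns(w, num_devices)]
    return [[v for out in partial for v in out[i]] for i in range(len(x))]


x = [[1, 2], [3, 4]]              # 2 tokens, hidden size 2
w = [[1, 0, 2, 1], [0, 1, 1, 3]]  # 2 -> 4 projection
print(column_parallel_linear(x, w, num_devices=2))  # [[1, 2, 4, 7], [3, 4, 10, 15]]
print(matmul(x, w))               # same result computed on one "device"
```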
What the Lab Covers
| Function | Concept |
|---|---|
| DDP_TEMPLATE | Production torchrun template |
| gradient_accumulation_demo() | Proves equivalence to large batch |
| allreduce_overhead_simulation() | Model size vs bandwidth chart |
| scaling_efficiency_plot() | Amdahl's law: NVLink/PCIe/InfiniBand |
Interview Questions
Q: DDP vs DataParallel — why always use DDP?
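A: DataParallel runs a single multi-threaded Python process. On every forward pass it replicates the model from GPU0 to the other GPUs, scatters the inputs, and gathers the outputs (and later the gradients) back on GPU0. As a result, GPU0 and Python's GIL become bottlenecks, and it only works on one machine. DDP runs one process per GPU and keeps replicas in sync with NCCL all-reduce overlapped with the backward pass. There is no central bottleneck, scaling is near-linear, and it works across multiple nodes.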
Q: What is the communication complexity of all-reduce for N GPUs with model size M?
A: With ring all-reduce, each GPU sends (and receives) $2(N-1)/N \cdot M$ data. For large $N$ this approaches $2M$ per GPU regardless of $N$, which makes it bandwidth-efficient. The number of sequential steps, $2(N-1)$, still grows with $N$.
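Turning the formula into time gives a feel for the bandwidth tradeoff. This is a bandwidth-only sketch that ignores latency and overlap with the backward pass. It assumes fp32 gradients and a hypothetical 1B-parameter model, and it reuses the illustrative NVLink and PCIe figures from the Amdahl section. Quoted link speeds are often bidirectional peaks, so real all-reduce times will be longer.

```python
def allreduce_seconds(num_params, num_gpus, bandwidth_gb_s, bytes_per_elem=4):
    """Bandwidth-only estimate of one ring all-reduce (latency ignored)."""
    model_bytes = num_params * bytes_per_elem            # M
    per_gpu_bytes = 2 * (num_gpus - 1) / num_gpus * model_bytes
    return per_gpu_bytes / (bandwidth_gb_s * 1e9)


for link, bw in [("NVLink", 600), ("PCIe", 50)]:
    for n in (2, 8, 64):
        ms = allreduce_seconds(1e9, n, bw) * 1e3
        print(f"{link:6s} N={n:2d}: {ms:7.1f} ms per all-reduce")
```

Going from 8 to 64 GPUs barely changes the estimate, because the $2(N-1)/N$ factor is already close to 2. Moving from NVLink to PCIe changes it by the bandwidth ratio.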
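Q: Gradient accumulation vs larger batch: are they truly equivalent?
A: For the gradient, mathematically yes, provided each mini-batch loss is divided by accum_steps and all mini-batches have the same size; use the learning rate you would use for the effective batch size. Practically, there are differences: (1) BatchNorm normalizes with mini-batch statistics, not effective-batch statistics, and updates its running stats once per mini-batch. (2) A smaller final mini-batch, or any uneven split, breaks exact equivalence. (3) It is slower per sample, since each small mini-batch uses the GPU less efficiently. Use it when memory limits the batch size.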
Q: What is find_unused_parameters=True in DDP and when do you need it?
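A: You need it when some parameters don't receive gradients in every iteration (e.g., conditional branches or unused heads). DDP's reducer waits for every parameter's gradient before it can finish all-reducing. Without the flag, the reduction never completes, so DDP raises an error at the next iteration (or ranks can hang). With the flag, DDP traverses the autograd graph after each forward pass and marks unused parameters as ready. That traversal adds overhead every iteration, so only enable it when needed.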
Run
```bash
# Single GPU:
python solution.py

# Multi-GPU with torchrun:
torchrun --nproc_per_node=4 solution.py

# Outputs saved to outputs/
```
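One way a single script can support both launch commands is to read the environment variables torchrun sets for every worker (RANK, LOCAL_RANK, WORLD_SIZE) and fall back to a single-process configuration when they are absent. The sketch below is an assumption about how such a template could be structured, not necessarily how solution.py does it, and the helper name dist_info is hypothetical. In the distributed case, a real template would then typically call torch.distributed.init_process_group(backend="nccl") and torch.cuda.set_device(local_rank), and would write outputs only on rank 0.

```python
import os


def dist_info():
    """Return (rank, local_rank, world_size) for this process.

    torchrun sets RANK, LOCAL_RANK and WORLD_SIZE for each worker it spawns;
    a plain `python solution.py` run sets none of them, so fall back to a
    single-process configuration.
    """
    rank = int(os.environ.get("RANK", 0))
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    return rank, local_rank, world_size


rank, local_rank, world_size = dist_info()
is_distributed = world_size > 1
is_main = rank == 0  # e.g. only the main process writes to outputs/
print(f"rank={rank} local_rank={local_rank} world_size={world_size} "
      f"distributed={is_distributed} main={is_main}")
```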