Distributed Training Architecture
Scaling training from 1 GPU to 100s of GPUs — theory, implementation, and tradeoffs.
Why Distributed Training?
| Constraint | Solution |
|---|---|
| Model doesn't fit in 1 GPU | Model parallelism, FSDP |
| Training too slow | Data parallelism (DDP) |
| Both | Hybrid parallelism (3D parallelism) |
Data Parallelism — DDP
Concept: Each GPU holds a full copy of the model. Each step:
- Split the mini-batch across N GPUs (each sees batch_size/N samples)
- Each GPU computes forward + backward independently
- AllReduce gradients across all GPUs (ring-allreduce via NCCL)
- All GPUs update identically → models stay in sync
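Key property: for a loss averaged over samples, DDP is mathematically equivalent to single-process training with a global batch size of N × batch_size_per_gpu (layers that use batch statistics, such as BatchNorm, are the exception). Because the effective batch grows with N, the learning rate is usually scaled too: lr = base_lr × N (linear scaling rule, Goyal et al.), typically with a warmup period.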
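# Launch: torchrun --nproc_per_node=8 train.py
# MyModel, dataset, and n_epochs are assumed to be defined elsewhere
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

def main():
    dist.init_process_group("nccl")  # NCCL for GPU-GPU, gloo for CPU
    rank = dist.get_rank()  # Global rank of this process (0-7 on one 8-GPU node)
    local_rank = int(os.environ["LOCAL_RANK"])  # GPU index on this node, set by torchrun
    torch.cuda.set_device(local_rank)
    model = MyModel().to(local_rank)
    model = DDP(model, device_ids=[local_rank],
                find_unused_parameters=False)  # False = faster
    # Optimizer choice and base LR (0.1) are placeholders; LR follows the linear scaling rule
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1 * dist.get_world_size())
    # Each rank sees a different shard of data
    sampler = DistributedSampler(dataset, num_replicas=dist.get_world_size(),
                                 rank=rank, shuffle=True)
    loader = DataLoader(dataset, sampler=sampler, batch_size=64,
                        pin_memory=True, num_workers=4)
    for epoch in range(n_epochs):
        sampler.set_epoch(epoch)  # Required for proper shuffling!
        for batch in loader:
            # Forward + backward same as single-GPU (model is assumed to return the loss)
            loss = model(batch)
            loss.backward()  # DDP hooks trigger AllReduce here
            optimizer.step()
            optimizer.zero_grad()
    dist.destroy_process_group()

if __name__ == "__main__":
    main()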
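The claim that averaging per-GPU gradients matches training on the global batch can be checked on a toy problem. The sketch below is a pure-Python illustration, not DDP itself: it assumes a one-parameter model `y = w * x` with a mean-squared-error loss, equal shard sizes, and models the AllReduce as a plain average. The helper names (`mse_grad`, `shard`, `ddp_style_grad`) are made up for this example.

```python
# Toy check: averaging per-rank gradients equals the full-batch gradient.
# Assumes equal shard sizes and a loss that is a mean over samples.

def mse_grad(w, xs, ys):
    """Gradient of mean((w*x - y)^2) with respect to w."""
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

def shard(seq, world_size, rank):
    """Give each rank every world_size-th sample, starting at its own rank."""
    return seq[rank::world_size]

def ddp_style_grad(w, xs, ys, world_size):
    """Each rank computes a local gradient; the AllReduce-mean is a plain average here."""
    local_grads = [
        mse_grad(w, shard(xs, world_size, r), shard(ys, world_size, r))
        for r in range(world_size)
    ]
    return sum(local_grads) / world_size

xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.0 * x for x in xs]  # targets generated by the "true" weight 2.0
w = 0.5

full_batch_grad = mse_grad(w, xs, ys)
sharded_grad = ddp_style_grad(w, xs, ys, world_size=4)
print(full_batch_grad, sharded_grad)  # both are -76.5
```

The equality relies on every rank holding the same number of samples; with uneven shards the plain average would weight samples differently.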
NCCL AllReduce
Ring-allreduce: each GPU sends and receives gradients in a ring topology.
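- Communication cost per GPU: $2(N-1)/N \times \text{gradient\_size}$, which approaches $2 \times \text{gradient\_size}$ and is therefore nearly independent of N.
- For N=8 GPUs: each GPU sends (and receives) 1.75× the gradient size in total, in 14 steps of 1/8 each. In a naive gather-and-broadcast or parameter-server scheme, the central node receives 7× the gradient size and sends 7× back.
- NVLink bandwidth (A100): 600 GB/s bidirectional, i.e. 300 GB/s per direction → under a simple ring model, AllReduce of 1GB of gradients across 8 GPUs has a bandwidth-only lower bound of about 5.8ms (1.75GB sent per GPU); measured times are higher because of latency and protocol overhead.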
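The ring pattern can be made concrete with a small simulation. This is a pure-Python sketch of the data movement only (a reduce-scatter phase followed by an all-gather phase), not how NCCL is implemented; the function name `ring_allreduce` and the assumption that the buffer length divides evenly by the number of GPUs are choices made for this example.

```python
def ring_allreduce(buffers):
    """Simulate a sum ring-allreduce over equal-length lists, one list per GPU.

    Returns (reduced_buffers, elements_sent_per_gpu).
    Toy assumption: the buffer length is divisible by the number of GPUs.
    """
    n = len(buffers)
    size = len(buffers[0])
    assert size % n == 0, "toy version: length must split evenly into n chunks"
    chunk = size // n
    data = [list(b) for b in buffers]
    sent = [0] * n

    def span(c):
        return range(c * chunk, (c + 1) * chunk)

    # Phase 1, reduce-scatter: in step s, GPU r sends chunk (r - s) mod n to
    # GPU r+1, which adds it to its own copy of that chunk.
    for s in range(n - 1):
        msgs = [(r, (r - s) % n) for r in range(n)]
        payloads = [[data[r][i] for i in span(c)] for r, c in msgs]
        for (r, c), payload in zip(msgs, payloads):
            dst = (r + 1) % n
            for i, v in zip(span(c), payload):
                data[dst][i] += v
            sent[r] += chunk

    # GPU r now holds the fully reduced chunk (r + 1) mod n.
    # Phase 2, all-gather: in step s, GPU r forwards chunk (r + 1 - s) mod n,
    # and the receiver overwrites its stale copy.
    for s in range(n - 1):
        msgs = [(r, (r + 1 - s) % n) for r in range(n)]
        payloads = [[data[r][i] for i in span(c)] for r, c in msgs]
        for (r, c), payload in zip(msgs, payloads):
            dst = (r + 1) % n
            for i, v in zip(span(c), payload):
                data[dst][i] = v
            sent[r] += chunk

    return data, sent

# 4 simulated GPUs, 8 gradient elements each
grads = [[float(10 * r + i) for i in range(8)] for r in range(4)]
reduced, sent = ring_allreduce(grads)
print(reduced[0])  # every GPU ends with the element-wise sum
print(sent)        # each GPU sent 2*(4-1)/4 * 8 = 12 elements
```

Each GPU only ever talks to its ring neighbor, and the per-GPU traffic matches the $2(N-1)/N$ factor: 12 elements for an 8-element buffer on 4 GPUs.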
Gradient Accumulation
Simulate a larger batch size without more GPU memory:
ACCUMULATE_STEPS = 8  # Effective batch = 8 × per_step_batch
# Loss scaler for fp16 autocast (torch.amp.GradScaler("cuda") on recent PyTorch)
scaler = torch.cuda.amp.GradScaler()
optimizer.zero_grad()
for step, (x, y) in enumerate(loader):
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        loss = model(x, y) / ACCUMULATE_STEPS  # Normalize loss!
    scaler.scale(loss).backward()
    # Gradients accumulate in .grad buffers
    if (step + 1) % ACCUMULATE_STEPS == 0:
        scaler.unscale_(optimizer)  # Unscale first so clipping sees true gradient norms
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()
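Why the loss is divided by the number of accumulation steps can be seen on a toy problem. The following pure-Python sketch assumes a one-parameter model with a mean-squared-error loss and equal-sized micro-batches; `mse_grad` and `accumulated_grad` are hypothetical helpers written for this illustration, and a float stands in for the `.grad` buffer.

```python
def mse_grad(w, xs, ys):
    """Gradient of mean((w*x - y)^2) with respect to w."""
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

def accumulated_grad(w, xs, ys, accumulate_steps, normalize):
    """Sum micro-batch gradients into one buffer, as repeated backward() calls would."""
    micro = len(xs) // accumulate_steps  # equal-sized micro-batches assumed
    grad_buffer = 0.0
    for step in range(accumulate_steps):
        mx = xs[step * micro:(step + 1) * micro]
        my = ys[step * micro:(step + 1) * micro]
        g = mse_grad(w, mx, my)
        grad_buffer += g / accumulate_steps if normalize else g
    return grad_buffer

xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [3.0 * x + 1.0 for x in xs]
w = 1.0

big_batch = mse_grad(w, xs, ys)
normalized = accumulated_grad(w, xs, ys, accumulate_steps=4, normalize=True)
unnormalized = accumulated_grad(w, xs, ys, accumulate_steps=4, normalize=False)
print(big_batch, normalized, unnormalized)
```

With normalization the accumulated gradient equals the large-batch gradient; without it the gradient is 4× too large, which acts like an unintended 4× learning-rate increase.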
With DDP: Use model.no_sync() context manager for accumulation steps to avoid expensive AllReduce on every backward — only sync on the last accumulation step:
import contextlib

for i, (x, y) in enumerate(loader):
    is_sync_step = (i + 1) % ACCUM == 0
    # Skip the AllReduce on every micro-batch except the last one in each window
    sync_context = contextlib.nullcontext() if is_sync_step else model.no_sync()
    with sync_context:
        loss = model(x, y) / ACCUM
        loss.backward()
    if is_sync_step:
        optimizer.step()
        optimizer.zero_grad()
FSDP — Fully Sharded Data Parallel
For models too large for 1 GPU (ViT-H, LLMs). FSDP shards model parameters, gradients, and optimizer states across GPUs:
DDP (N=4 GPUs):
  GPU0: full model copy (10GB) + 10GB gradients + 20GB optim states = 40GB
  GPU1: full model copy (10GB) + 10GB gradients + 20GB optim states = 40GB
FSDP (N=4 GPUs):
  GPU0: 1/4 of params (2.5GB) + 1/4 gradients (2.5GB) + 1/4 optim (5GB) = 10GB ✅
  GPU1: 1/4 of params ...
  During forward: every GPU all-gathers the other shards of the current layer → full layer weights
                  → runs layer → frees the non-owned shards
  During backward: all-gathers the layer again → computes gradients
                  → reduce-scatters them so each GPU keeps only its own gradient shard
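import functools
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

mp_policy = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16,
)
# The policy must be told which module class to wrap. TransformerBlock is a
# placeholder for the model's own transformer layer class.
wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={TransformerBlock},
)
model = FSDP(model,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    mixed_precision=mp_policy,
    auto_wrap_policy=wrap_policy,  # one FSDP unit per transformer block
)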
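The per-GPU totals in the DDP vs FSDP comparison above follow from simple arithmetic. The helper below is a back-of-the-envelope sketch: `per_gpu_model_state_gb` is a hypothetical name, and it deliberately ignores activations and the temporary memory for the layer that is currently all-gathered.

```python
def per_gpu_model_state_gb(params_gb, grads_gb, optim_gb, world_size, sharded):
    """Memory for parameters + gradients + optimizer states on one GPU.

    sharded=False models DDP (full replica on every GPU);
    sharded=True models full sharding (each GPU keeps 1/world_size of everything).
    Activations and temporarily gathered layers are not counted.
    """
    total = params_gb + grads_gb + optim_gb
    return total / world_size if sharded else total

ddp_gb = per_gpu_model_state_gb(10, 10, 20, world_size=4, sharded=False)
fsdp_gb = per_gpu_model_state_gb(10, 10, 20, world_size=4, sharded=True)
print(ddp_gb, fsdp_gb)  # 40 and 10.0, matching the figures above
```

In practice the FSDP peak is somewhat higher than this estimate, because at least one wrapped unit is held unsharded while it executes.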
3D Parallelism (LLM scale)
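Used to train LLMs with hundreds of billions of parameters, for example Megatron-Turing NLG 530B (Megatron-LM + DeepSpeed). Frontier models such as GPT-4 and Gemini are assumed to use similar schemes, but their training setups and parameter counts are not public: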
Tensor Parallelism (TP)
Split single layer across GPUs
◄─────────────────────────►
┌────┬────┐ ┌────┬────┐
│TP0 │TP1 │ │TP0 │TP1 │ ← Pipeline Stage 0 (layers 1-12)
└────┴────┘ └────┴────┘
┌────┬────┐ ┌────┬────┐
│TP0 │TP1 │ │TP0 │TP1 │ ← Pipeline Stage 1 (layers 13-24)
└────┴────┘ └────┴────┘
▲ ▲
Pipeline Parallelism (PP): stages on different GPU groups
Data Parallelism (DP): entire pipeline replicated for batch throughput
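One way to read the diagram is as a 3D grid of GPUs indexed by (DP replica, pipeline stage, TP shard). The sketch below maps a global rank to those coordinates for the 2 × 2 × 2 layout drawn above. The ordering, with TP varying fastest so that tensor-parallel peers sit on neighboring ranks, is an assumption of this example; real frameworks define their own rank layouts, and `rank_to_coord` and `tp_groups` are hypothetical helpers.

```python
from collections import namedtuple

Coord = namedtuple("Coord", ["dp", "pp", "tp"])

def rank_to_coord(rank, tp_size, pp_size, dp_size):
    """Map a global rank to (dp, pp, tp), with TP as the fastest-varying index."""
    world_size = tp_size * pp_size * dp_size
    if not 0 <= rank < world_size:
        raise ValueError(f"rank {rank} outside world of size {world_size}")
    tp = rank % tp_size
    pp = (rank // tp_size) % pp_size
    dp = rank // (tp_size * pp_size)
    return Coord(dp=dp, pp=pp, tp=tp)

def tp_groups(tp_size, pp_size, dp_size):
    """Ranks that jointly hold one layer: same dp and pp, different tp."""
    groups = {}
    for rank in range(tp_size * pp_size * dp_size):
        c = rank_to_coord(rank, tp_size, pp_size, dp_size)
        groups.setdefault((c.dp, c.pp), []).append(rank)
    return list(groups.values())

# Layout from the diagram: TP=2, PP=2 stages, DP=2 replicas -> 8 GPUs
print(rank_to_coord(5, tp_size=2, pp_size=2, dp_size=2))
print(tp_groups(2, 2, 2))
```

Placing TP innermost reflects the usual reasoning that tensor parallelism communicates inside every layer and therefore benefits most from the fastest links.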
Training Efficiency Tips
Gradient Checkpointing (Activation Checkpointing)
Forward pass stores only a subset of activations; recomputes the rest during backward.
- Memory: 60-70% reduction in activation memory
- Speed: ~30% slower (extra forward passes)
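from torch.utils.checkpoint import checkpoint_sequential
# Recompute activations every 4 layers during backward
# (model.layers is assumed to be an nn.Sequential or a list of modules)
output = checkpoint_sequential(model.layers, segments=len(model.layers)//4, input=x,
                               use_reentrant=False)  # non-reentrant variant, recommended in recent PyTorch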
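The memory and speed figures above can be approximated with a simple cost model. The sketch below is a rough estimate under stated assumptions, not a measurement: every layer's activations are the same size, backward costs twice as much as forward, and each layer is recomputed exactly once. `checkpoint_tradeoff` is a hypothetical helper.

```python
import math

def checkpoint_tradeoff(num_layers, segments):
    """Estimate stored activations and slowdown for segment-wise checkpointing.

    Assumptions (simplifications): uniform activation size per layer,
    backward = 2x forward cost, one extra forward pass in total.
    """
    layers_per_segment = math.ceil(num_layers / segments)
    # One saved input per segment, plus the activations of the single segment
    # that is being recomputed at any moment during backward.
    stored_units = segments + layers_per_segment
    memory_fraction = stored_units / num_layers
    baseline_cost = 1.0 + 2.0                # forward + backward
    checkpointed_cost = baseline_cost + 1.0  # plus one recomputed forward
    slowdown = checkpointed_cost / baseline_cost - 1.0
    return stored_units, memory_fraction, slowdown

stored, fraction, slowdown = checkpoint_tradeoff(num_layers=48, segments=12)
print(f"stored {stored} of 48 layer activations "
      f"({1 - fraction:.0%} less), about {slowdown:.0%} slower")
```

For 48 layers in segments of 4, this model gives roughly a two-thirds reduction in activation memory for about one-third more compute, in line with the ranges quoted above.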
torch.compile (PyTorch 2.0+)
model = torch.compile(model, mode='max-autotune')
# mode options:
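# 'default': good balance of compile time and speedup (gains are workload-dependent)
# 'reduce-overhead': uses CUDA graphs to cut Python and kernel-launch overhead (helps small models and small batches)
# 'max-autotune': benchmarks many kernel configurations (slow compile, usually the fastest runtime)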
Communication Overlap
DDP overlaps gradient computation with AllReduce. Gradients are grouped into buckets, and backward runs from the last layer to the first: as soon as a bucket's gradients are ready, its AllReduce starts while backward continues through the earlier layers. This is automatic in DDP.
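The benefit of overlap can be estimated with a small timeline model. This is a toy sketch, not a profiler: it assumes gradients are reduced in buckets, that a bucket's AllReduce can start once its backward work is done, and that reductions run one at a time on a single communication stream. The function name and the millisecond values are made up for illustration.

```python
def step_time_ms(backward_ms, comm_ms, overlap):
    """Backward + AllReduce time for buckets listed in the order they become ready."""
    if not overlap:
        # All communication happens after the whole backward pass.
        return sum(backward_ms) + sum(comm_ms)
    compute_done = 0.0  # time at which the current bucket's gradients are ready
    comm_done = 0.0     # time at which the communication stream becomes free
    for b, c in zip(backward_ms, comm_ms):
        compute_done += b
        # A reduction starts once its gradients exist and the stream is idle.
        comm_done = max(comm_done, compute_done) + c
    return max(compute_done, comm_done)

backward = [4.0, 4.0, 4.0, 4.0]  # hypothetical per-bucket backward times
comm = [3.0, 3.0, 3.0, 3.0]      # hypothetical per-bucket AllReduce times

print(step_time_ms(backward, comm, overlap=False))  # 28.0
print(step_time_ms(backward, comm, overlap=True))   # 19.0
```

With overlap, only the last bucket's reduction is exposed; the rest is hidden behind backward computation as long as communication per bucket is not slower than compute.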
Interview Questions
Q: How does DistributedDataParallel achieve linear scaling efficiency?
A: DDP achieves near-linear scaling through communication-compute overlap and the efficiency of ring-allreduce. Gradients are grouped into buckets; as soon as a bucket's gradients are ready during backward, DDP starts AllReducing them while autograd keeps computing gradients for earlier layers, so communication and computation run in parallel. Ring-allreduce has a per-GPU communication cost of 2(N-1)/N × gradient_size, which stays below 2 × gradient_size regardless of the number of GPUs. In practice, DDP on 8 A100s with NVLink (600 GB/s) can reach roughly 7.5× speedup (about 94% efficiency), depending on model size and batch size.
Q: When would you use FSDP over DDP?
A: Use FSDP when the model + optimizer states don't fit on a single GPU. With DDP and mixed-precision Adam, each GPU needs: 2 bytes (fp16 param) + 2 bytes (fp16 grad) + 12 bytes (fp32 master weight + two fp32 Adam moments) = 16 bytes/param, before activations. A 1B parameter model needs 16GB per GPU, which is feasible. A 10B model needs 160GB per GPU, which is impossible even on an A100 (80GB). FSDP shards parameters, gradients, and optimizer states across GPUs, so this per-GPU footprint drops to roughly 1/N. The tradeoff: FSDP has higher communication overhead (an AllGather before each wrapped unit's forward and backward, plus a ReduceScatter of gradients), but that cost is unavoidable when the model does not fit otherwise.
Q: You scale DDP from 1 to 8 GPUs and the training loss curves don't match. Why?
A: Several causes: (1) Learning rate not scaled: with an 8× larger effective batch, you need a higher LR, either ~2.83× (sqrt scaling, √8) or 8× (linear scaling), with warmup. (2) BatchNorm statistics: each GPU computes BN stats on its local data shard only (batch/8 if the global batch is held fixed), leading to noisy stats. Fix: use torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) to synchronize BN across GPUs. (3) DistributedSampler epoch not set: without sampler.set_epoch(epoch), the sampler reuses the same shuffle seed, so every epoch replays the same data order on each GPU instead of reshuffling.
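The two learning-rate scaling rules mentioned in cause (1) can be written down directly. This is a minimal sketch: `scaled_lr` and `warmup_lr` are hypothetical helpers, the base LR of 0.1 and the 500-step warmup are placeholder values, and the warmup is a simple linear ramp from the base LR to the scaled LR.

```python
import math

def scaled_lr(base_lr, world_size, rule="linear"):
    """Scale the single-GPU learning rate for a world_size-times larger batch."""
    if rule == "linear":
        return base_lr * world_size
    if rule == "sqrt":
        return base_lr * math.sqrt(world_size)
    raise ValueError(f"unknown rule: {rule}")

def warmup_lr(base_lr, target_lr, step, warmup_steps):
    """Linearly ramp from base_lr to target_lr over warmup_steps, then hold."""
    if step >= warmup_steps:
        return target_lr
    return base_lr + (target_lr - base_lr) * step / warmup_steps

base_lr = 0.1  # placeholder single-GPU learning rate
linear_target = scaled_lr(base_lr, world_size=8, rule="linear")
sqrt_target = scaled_lr(base_lr, world_size=8, rule="sqrt")
print(f"linear: {linear_target:.3f}, sqrt: {sqrt_target:.3f}")

# Learning rate at a few points of a 500-step warmup toward the linear target
for step in (0, 250, 500):
    print(step, round(warmup_lr(base_lr, linear_target, step, warmup_steps=500), 3))
```

Starting the ramp at the single-GPU rate avoids applying the full scaled rate while the early, rapidly changing gradients make large steps unstable.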