Distributed Training Architecture

Scaling training from 1 GPU to 100s of GPUs — theory, implementation, and tradeoffs.


Why Distributed Training?

| Constraint | Solution |
|---|---|
| Model doesn't fit in 1 GPU | Model parallelism, FSDP |
| Training too slow | Data parallelism (DDP) |
| Both | Hybrid parallelism (3D parallelism) |

Data Parallelism — DDP

Concept: Each GPU holds a full copy of the model. Each step:

  1. Split the mini-batch across N GPUs (each sees batch_size/N samples)
  2. Each GPU computes forward + backward independently
  3. AllReduce gradients across all GPUs (ring-allreduce via NCCL)
  4. All GPUs update identically → models stay in sync

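Key property: because DDP averages gradients across ranks, it is mathematically equivalent to single-GPU training with a global batch size of N × batch_size_per_gpu. This holds provided the loss is a per-sample mean and there are no cross-sample layers such as BatchNorm, which only see the local shard. This is why the learning rate is usually scaled: lr = base_lr × N (linear scaling rule, Goyal et al.), typically combined with a gradual warmup.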
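The equivalence claim can be checked with a minimal pure-Python sketch. It assumes a 1-D linear model with a mean-squared-error loss and a batch that divides evenly across GPUs; the model, the data, and the helper names (`mse_grad`, `ddp_grad`, `linear_scaled_lr`, `sqrt_scaled_lr`) are hypothetical. Averaging the per-shard mean gradients, which is what DDP's AllReduce-then-divide does, reproduces the full-batch mean gradient. The LR helpers apply the linear scaling rule with a linear warmup from base_lr (the gradual-warmup idea from Goyal et al.) and, for comparison, square-root scaling.

```python
import math


def mse_grad(w, b, batch):
    """Mean gradient of 0.5 * (w*x + b - y)**2 over a batch, w.r.t. (w, b)."""
    gw = gb = 0.0
    for x, y in batch:
        err = w * x + b - y
        gw += err * x
        gb += err
    return gw / len(batch), gb / len(batch)


def ddp_grad(w, b, global_batch, n_gpus):
    """Give each simulated GPU an equal shard, take each shard's mean gradient,
    then AllReduce-sum and divide by N, as DDP does."""
    shard = len(global_batch) // n_gpus
    grads = [mse_grad(w, b, global_batch[r * shard:(r + 1) * shard])
             for r in range(n_gpus)]
    return (sum(g[0] for g in grads) / n_gpus,
            sum(g[1] for g in grads) / n_gpus)


def linear_scaled_lr(base_lr, n_gpus, step, warmup_steps):
    """Linear scaling rule: ramp linearly from base_lr to base_lr * N over
    warmup_steps, then hold (the warmup length is a hypothetical choice)."""
    target = base_lr * n_gpus
    if step >= warmup_steps:
        return target
    return base_lr + (target - base_lr) * step / warmup_steps


def sqrt_scaled_lr(base_lr, n_gpus):
    """Square-root scaling: a more conservative alternative."""
    return base_lr * math.sqrt(n_gpus)


batch = [(float(i), 2.0 * i + 1.0) for i in range(32)]  # samples of y = 2x + 1
print(mse_grad(0.5, 0.0, batch), ddp_grad(0.5, 0.0, batch, n_gpus=8))
print(linear_scaled_lr(0.1, n_gpus=8, step=50, warmup_steps=100))
```

The two printed gradient tuples agree up to floating-point rounding. With unequal shards, the mean of means would no longer equal the global mean, which is one reason DistributedSampler pads every rank to the same length.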

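# Launch: torchrun --nproc_per_node=8 train.py
import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

def main():
    dist.init_process_group("nccl")  # NCCL for GPU-GPU, gloo for CPU
    rank = dist.get_rank()                      # global rank across all nodes
    local_rank = int(os.environ["LOCAL_RANK"])  # GPU index on this node (set by torchrun)
    torch.cuda.set_device(local_rank)

    # MyModel, dataset and n_epochs are placeholders defined elsewhere;
    # MyModel's forward is assumed to return a scalar loss.
    model = MyModel().to(local_rank)
    model = DDP(model, device_ids=[local_rank],
                find_unused_parameters=False)  # False = faster
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)  # illustrative choice

    # Each rank sees a different shard of data
    sampler = DistributedSampler(dataset, num_replicas=dist.get_world_size(),
                                 rank=rank, shuffle=True)
    # batch_size is per GPU: global batch = 64 × world_size
    loader = DataLoader(dataset, sampler=sampler, batch_size=64,
                        pin_memory=True, num_workers=4)

    for epoch in range(n_epochs):
        sampler.set_epoch(epoch)  # Required for proper shuffling!
        for batch in loader:
            batch = batch.to(local_rank, non_blocking=True)  # assumes a tensor batch
            # Forward + backward same as single-GPU
            loss = model(batch)
            loss.backward()  # DDP hooks trigger AllReduce here
            optimizer.step()
            optimizer.zero_grad()

    dist.destroy_process_group()

if __name__ == "__main__":
    main()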

NCCL AllReduce

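Ring-allreduce: the GPUs form a ring, and each gradient buffer is split into N chunks. In the reduce-scatter phase (N-1 steps), every GPU sends one chunk to its neighbor and adds the chunk it receives, so each GPU ends up holding the full sum of one chunk. In the all-gather phase (another N-1 steps), the summed chunks travel around the ring until every GPU holds the complete result.

  • Communication cost per GPU: $2(N-1)/N \times \text{gradient\_size}$ (two phases of N-1 steps, each moving 1/N of the buffer). This approaches 2 × gradient_size and is nearly independent of N!
  • For N=8 GPUs: each GPU sends 2 × 7/8 = 1.75× the gradient size (87.5% in each phase). A naive scheme that gathers everything on one node (e.g. a parameter server) makes that node receive 7× the gradient size and send 7× back, so its traffic grows linearly with N.
  • NVLink bandwidth (A100): 600 GB/s total, i.e. 300 GB/s per direction. A ring AllReduce of 1 GB of gradients across 8 GPUs sends 1.75 GB per GPU, so the bandwidth-only lower bound is about 1.75 GB / 300 GB/s ≈ 5.8 ms. Real runs take longer because achieved bandwidth stays below peak.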
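The two ring phases can be checked with a pure-Python simulation that sums plain lists in place of GPU tensors. It assumes the buffer length is divisible by N. `ring_allreduce` and `ring_time_ms` are hypothetical helpers, not NCCL APIs. The simulation counts the elements each worker sends, which comes out to 2(N-1)/N of the buffer. It then converts that count into a bandwidth-only time bound that ignores latency and the gap between achieved and peak bandwidth.

```python
def ring_allreduce(buffers):
    """Simulate a sum-AllReduce over a ring of len(buffers) workers.

    Returns (results, elements_sent_per_worker). Assumes every buffer has the
    same length, divisible by the number of workers.
    """
    n = len(buffers)
    size = len(buffers[0]) // n
    assert size * n == len(buffers[0]), "sketch assumes length divisible by N"
    # chunks[w][c] is worker w's current copy of chunk c
    chunks = [[list(buf[c * size:(c + 1) * size]) for c in range(n)]
              for buf in buffers]
    sent = [0] * n

    # Phase 1, reduce-scatter: at step s, worker w sends chunk (w - s) to its
    # right neighbor, which adds it to its own copy. After N-1 steps,
    # worker w holds the fully summed chunk (w + 1) % N.
    for s in range(n - 1):
        msgs = [(w, (w - s) % n, list(chunks[w][(w - s) % n])) for w in range(n)]
        for w, c, data in msgs:
            dst = (w + 1) % n
            chunks[dst][c] = [a + b for a, b in zip(chunks[dst][c], data)]
            sent[w] += len(data)

    # Phase 2, all-gather: pass the summed chunks around the ring; the
    # receiver overwrites its stale copy. After N-1 steps all copies match.
    for s in range(n - 1):
        msgs = [(w, (w + 1 - s) % n, list(chunks[w][(w + 1 - s) % n]))
                for w in range(n)]
        for w, c, data in msgs:
            chunks[(w + 1) % n][c] = data
            sent[w] += len(data)

    results = [[x for c in range(n) for x in chunks[w][c]] for w in range(n)]
    return results, sent


def ring_time_ms(grad_gb, n, gb_per_s_per_direction):
    """Bandwidth-only lower bound: each GPU sends 2(N-1)/N × grad_gb."""
    return 2 * (n - 1) / n * grad_gb / gb_per_s_per_direction * 1000


results, sent = ring_allreduce([[1] * 16 for _ in range(8)])
print(results[0], sent)
print(ring_time_ms(1.0, 8, 300.0))
```

With 8 workers and 16 elements each, every worker sends 28 elements (1.75 × 16). The count is the same for every worker, which is why no single GPU becomes a bottleneck the way a parameter server does.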

Gradient Accumulation

Simulate a larger batch size without more GPU memory:

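import torch

# model, loader and optimizer are placeholders defined elsewhere;
# model(x, y) is assumed to return a scalar loss.
ACCUMULATE_STEPS = 8  # Effective batch = 8 × per_step_batch
scaler = torch.amp.GradScaler("cuda")  # PyTorch >= 2.3 (older: torch.cuda.amp.GradScaler())
optimizer.zero_grad()

for step, (x, y) in enumerate(loader):
    x, y = x.cuda(non_blocking=True), y.cuda(non_blocking=True)
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        loss = model(x, y) / ACCUMULATE_STEPS  # Normalize loss!
    scaler.scale(loss).backward()
    # Gradients accumulate in .grad buffers

    if (step + 1) % ACCUMULATE_STEPS == 0:
        scaler.unscale_(optimizer)  # clip the true (unscaled) gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()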

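With DDP: wrap the forward and backward of the non-final accumulation steps in the model.no_sync() context manager. Inside it, DDP skips the expensive AllReduce and gradients simply accumulate locally. The backward of the last accumulation step runs outside no_sync() and synchronizes the accumulated gradients once: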

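import contextlib

# ACCUMULATE_STEPS as above; model is the DDP-wrapped module
for i, (x, y) in enumerate(loader):
    is_sync_step = (i + 1) % ACCUMULATE_STEPS == 0
    # no_sync() must cover both forward and backward of the non-final steps
    sync_context = contextlib.nullcontext() if is_sync_step else model.no_sync()
    with sync_context:
        loss = model(x, y) / ACCUMULATE_STEPS
        loss.backward()
    if is_sync_step:
        optimizer.step()
        optimizer.zero_grad()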
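The following pure-Python sketch shows what the no_sync() pattern computes. It assumes N simulated workers, K accumulation steps, and a 1-D linear model with MSE loss; all names and sizes are hypothetical. Each worker sums the gradients of loss / K over its K micro-batches locally. A single average across workers per optimizer step, standing in for the one AllReduce, then matches the mean gradient over all N × K micro-batches. The effective batch is therefore N × K × micro_batch_size. `sync_steps` mirrors the `(i + 1) % ACCUMULATE_STEPS == 0` test in the loop above.

```python
def mse_grad(w, batch):
    """Mean gradient of 0.5 * (w*x - y)**2 w.r.t. w over one micro-batch."""
    return sum((w * x - y) * x for x, y in batch) / len(batch)


def accumulate_then_sync(w, micro_batches_per_worker, accum_steps):
    """Each worker sums grad(loss / K) over its K micro-batches locally
    (the no_sync() steps); one AllReduce-style average then runs on step K."""
    local = []
    for micro_batches in micro_batches_per_worker:
        assert len(micro_batches) == accum_steps
        local.append(sum(mse_grad(w, mb) / accum_steps for mb in micro_batches))
    return sum(local) / len(local)


def sync_steps(num_micro_batches, accum_steps):
    """Loop indices i whose backward runs outside no_sync(), i.e. triggers AllReduce."""
    return [i for i in range(num_micro_batches) if (i + 1) % accum_steps == 0]


def effective_batch(n_gpus, accum_steps, micro_batch):
    return n_gpus * accum_steps * micro_batch


N, K, MB = 4, 8, 2  # GPUs, accumulation steps, micro-batch size (illustrative)
samples = [(0.1 * i, 0.3 * i) for i in range(effective_batch(N, K, MB))]
per_worker = [[samples[(r * K + k) * MB:(r * K + k + 1) * MB] for k in range(K)]
              for r in range(N)]
print(accumulate_then_sync(1.0, per_worker, K), mse_grad(1.0, samples))
print(sync_steps(16, K))
```

Dropping the `/ accum_steps` normalization would make the accumulated gradient K times too large. That is the same effect as silently multiplying the learning rate by K.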

FSDP — Fully Sharded Data Parallel

For models too large for 1 GPU (ViT-H, LLMs). FSDP shards model parameters, gradients, and optimizer states across GPUs:

DDP (N=4 GPUs):
  GPU0: full model copy (10GB) + 10GB gradients + 20GB optim states = 40GB
  GPU1: full model copy (10GB) + 10GB gradients + 20GB optim states = 40GB
  
FSDP (N=4 GPUs):
  GPU0: 1/4 of params (2.5GB) + 1/4 gradients (2.5GB) + 1/4 optim (5GB) = 10GB ✅
  GPU1: 1/4 of params ...
  
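  (Activations and the temporarily gathered layer weights come on top of this.)

  Forward, per FSDP unit:  AllGather shards from all GPUs → full layer weights
                           → run layer → free non-owned params
  Backward, per FSDP unit: AllGather again → compute gradients
                           → ReduceScatter so each GPU keeps only its gradient shard
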
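import functools

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

# Assumes dist.init_process_group() and torch.cuda.set_device(local_rank) have
# run, as in the DDP example. TransformerBlock is a placeholder for your block class.
mp_policy = MixedPrecision(
    param_dtype=torch.bfloat16,   # gathered params / compute in bf16
    reduce_dtype=torch.bfloat16,  # gradient ReduceScatter in bf16
    buffer_dtype=torch.bfloat16,
)

# Each TransformerBlock becomes its own FSDP unit (gathered and freed per block)
wrap_policy = functools.partial(transformer_auto_wrap_policy,
                                transformer_layer_cls={TransformerBlock})

model = FSDP(model,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    mixed_precision=mp_policy,
    auto_wrap_policy=wrap_policy,
    device_id=torch.cuda.current_device(),
)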
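The per-GPU numbers in the memory diagram can be reproduced with a back-of-the-envelope calculator. It assumes Adam and ignores activations, temporary buffers, and the layer FSDP gathers on the fly; the function and its parameters are hypothetical. Both fp32 Adam (4 B param + 4 B grad + 8 B for two moments) and mixed-precision Adam (2 + 2 + 12 B, where the 12 B are an fp32 master copy plus two fp32 moments) come to 16 bytes per parameter. The diagram's 10 GB of fp32 parameters corresponds to 2.5B parameters.

```python
def model_state_gb(n_params, n_gpus=1, sharded=False,
                   param_bytes=4, grad_bytes=4, optim_bytes=8):
    """Per-GPU GB (1e9 bytes) for parameters + gradients + Adam states.

    Defaults are fp32 Adam (4 B param, 4 B grad, 8 B for two moments).
    sharded=True divides everything by n_gpus, as FULL_SHARD does.
    Activations and temporarily gathered layers are not counted.
    """
    total_bytes = n_params * (param_bytes + grad_bytes + optim_bytes)
    if sharded:
        total_bytes /= n_gpus
    return total_bytes / 1e9


# The diagram: 10 GB of fp32 params = 2.5B parameters
print(model_state_gb(2.5e9))                          # DDP, per GPU
print(model_state_gb(2.5e9, n_gpus=4, sharded=True))  # FSDP, N=4

# Mixed-precision Adam: fp16/bf16 param + grad, fp32 master weight + 2 moments
MIXED = dict(param_bytes=2, grad_bytes=2, optim_bytes=12)
print(model_state_gb(10e9, **MIXED))                          # 10B model, DDP
print(model_state_gb(10e9, n_gpus=8, sharded=True, **MIXED))  # 10B model, FSDP N=8
```

The four lines print 40, 10, 160 and 20 (GB). The 10B model therefore exceeds an 80 GB GPU under DDP but leaves room for activations once sharded over 8 GPUs.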

3D Parallelism (LLM scale)

Used to train the largest LLMs (hundreds of billions of parameters and beyond):

         Tensor Parallelism (TP)
         Split single layer across GPUs
         ◄─────────────────────────►
    ┌────┬────┐   ┌────┬────┐
    │TP0 │TP1 │   │TP0 │TP1 │   ← Pipeline Stage 0 (layers 1-12)
    └────┴────┘   └────┴────┘
    ┌────┬────┐   ┌────┬────┐
    │TP0 │TP1 │   │TP0 │TP1 │   ← Pipeline Stage 1 (layers 13-24)
    └────┴────┘   └────┴────┘
         ▲                 ▲
    Pipeline Parallelism (PP): stages on different GPU groups
    Data Parallelism (DP): entire pipeline replicated for batch throughput
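The layout in the diagram (2 TP × 2 PP × 2 DP = 8 GPUs) can be sketched in pure Python. The mapping from a flat rank to grid coordinates assumes TP is the fastest-varying dimension, so that TP partners, which exchange the most data, land on adjacent GPUs of the same node. This is a common choice, and the function names are hypothetical. The second function shows why tensor parallelism works for a linear layer: if each TP rank owns a slice of the weight's output features, it can compute its slice of the output independently, and concatenating the slices reproduces the full result.

```python
def rank_to_coords(rank, tp, pp, dp):
    """Map a global rank to (dp, pp, tp) indices with TP varying fastest."""
    assert 0 <= rank < tp * pp * dp
    return rank // (tp * pp), (rank // tp) % pp, rank % tp


def matvec(w_rows, x):
    """y = W x, with W stored as a list of rows (one row per output feature)."""
    return [sum(wij * xj for wij, xj in zip(row, x)) for row in w_rows]


def output_split_linear(w_rows, x, tp):
    """Tensor-parallel linear layer: each of tp ranks owns a slice of W's
    output features, computes its slice of y, and the slices are concatenated
    (an all-gather in a real implementation)."""
    assert len(w_rows) % tp == 0
    step = len(w_rows) // tp
    partials = [matvec(w_rows[r * step:(r + 1) * step], x) for r in range(tp)]
    return [v for part in partials for v in part]


for r in range(8):  # the 2 (TP) × 2 (PP) × 2 (DP) layout in the diagram
    dp_i, pp_i, tp_i = rank_to_coords(r, tp=2, pp=2, dp=2)
    print(f"rank {r}: DP replica {dp_i}, pipeline stage {pp_i}, TP{tp_i}")
```

Under this layout, ranks 0 to 3 form the first pipeline replica and ranks 4 to 7 the second, so gradient AllReduce for data parallelism runs between rank r and rank r + 4.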

Training Efficiency Tips

Gradient Checkpointing (Activation Checkpointing)

Forward pass stores only a subset of activations; recomputes the rest during backward.

  • Memory: typically a 60-70% reduction in activation memory, depending on segment size
  • Speed: roughly one extra forward pass per step, so typically ~30% slower

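from torch.utils.checkpoint import checkpoint_sequential
# model.layers: an nn.Sequential (or list) of blocks; x: the input tensor.
# segments = len // 4 → 4 layers per segment. Only segment inputs are stored;
# during backward each checkpointed segment (all but the last) re-runs its forward.
output = checkpoint_sequential(model.layers, segments=len(model.layers) // 4,
                               input=x, use_reentrant=False)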
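A simplified pure-Python cost model makes the trade-off concrete. It assumes a chain of identical layers where each stored layer output costs one memory unit, a layer forward costs one compute unit, and a backward costs about two (a common rule of thumb). The function names are hypothetical. With k segments, only the k segment inputs are kept during forward, and backward recomputes one segment at a time. The model treats every segment as checkpointed, a slight simplification of checkpoint_sequential, which leaves the last segment un-checkpointed.

```python
import math


def peak_activation_units(n_layers, segments=None):
    """Peak stored activations for a chain of n_layers (1 unit per layer output).

    segments=None: no checkpointing, every layer output is kept.
    Otherwise: only each segment's input is kept during forward; during
    backward one segment at a time is recomputed, holding its outputs.
    """
    if segments is None:
        return n_layers
    return segments + math.ceil(n_layers / segments)


def step_compute_units(n_layers, checkpointed, backward_per_layer=2.0):
    """Relative cost: forward 1 unit/layer, backward ~2 units/layer (rule of
    thumb). Checkpointing runs the forward a second time during backward."""
    forward = n_layers * (2 if checkpointed else 1)
    return forward + backward_per_layer * n_layers


n = 48
print(peak_activation_units(n), peak_activation_units(n, segments=n // 4))
print(step_compute_units(n, True) / step_compute_units(n, False))
```

For 48 layers checkpointed every 4 layers, the model gives 16 instead of 48 units (about 67% less) for 4/3 of the compute, in line with the figures above. The memory term segments + n/segments is smallest near segments ≈ √n.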

torch.compile (PyTorch 2.0+)

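import torch

model = torch.compile(model, mode="max-autotune")
# mode options (speedups are model- and hardware-dependent):
# 'default'         : balanced compile time and runtime; the safe starting point
# 'reduce-overhead' : uses CUDA graphs to cut per-launch Python overhead (small models/batches)
# 'max-autotune'    : autotunes kernel configurations (slowest compile, often fastest runtime)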

Communication Overlap

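DDP overlaps gradient computation with AllReduce. It groups gradients into buckets (bucket_cap_mb, 25 MB by default) in roughly reverse layer order. As soon as every gradient in a bucket has been computed, DDP launches that bucket's AllReduce asynchronously while backward continues through earlier layers. This is automatic in DDP.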
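The bucketing idea can be sketched in pure Python, assuming hypothetical per-layer gradient sizes in MB and a 25 MB cap. Gradients become ready in reverse layer order during backward and are packed into buckets in that order. Each bucket's AllReduce can start right after the backward of the earliest layer it contains. This illustrates the idea only; DDP's actual bucket assignment has extra rules, and the function names are hypothetical.

```python
def plan_buckets(layer_grad_mb, cap_mb=25.0):
    """Group layers into buckets in backward (reverse) order.

    Returns a list of buckets, each a list of layer indices. A bucket is
    closed once adding the next layer would exceed cap_mb (a layer larger
    than the cap gets a bucket of its own).
    """
    buckets, current, size = [], [], 0.0
    for layer in reversed(range(len(layer_grad_mb))):
        mb = layer_grad_mb[layer]
        if current and size + mb > cap_mb:
            buckets.append(current)
            current, size = [], 0.0
        current.append(layer)
        size += mb
    if current:
        buckets.append(current)
    return buckets


def launch_points(buckets):
    """A bucket's AllReduce can start once its last-ready gradient exists,
    i.e. after backward of its lowest-index layer."""
    return [min(b) for b in buckets]


buckets = plan_buckets([10.0] * 6)  # six layers, 10 MB of gradients each
for b, start in zip(buckets, launch_points(buckets)):
    print(f"bucket {b}: AllReduce can start after layer {start}'s backward")
```

Only the last bucket, which holds the first layers, has no backward work left to hide behind. Its AllReduce is the part of communication that typically stays exposed.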


Interview Questions

Q: How does DistributedDataParallel achieve near-linear scaling efficiency?

A: DDP achieves near-linear scaling through communication-compute overlap and ring-allreduce efficiency. Gradients are grouped into buckets, and as soon as a bucket is ready (starting with the last layers), DDP starts AllReducing it while backward is still computing gradients for earlier layers. Communication and computation therefore happen in parallel. Ring-allreduce's per-GPU communication volume is 2(N-1)/N × gradient_size, which is nearly independent of the number of GPUs. In practice, DDP on a single 8×A100 node with NVLink (600 GB/s) can reach roughly 7.5× speedup (about 94% efficiency) for compute-heavy models. The exact figure depends on model size, batch size and interconnect.
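A toy cost model makes the overlap argument concrete. It assumes each GPU's step time is its compute time plus whatever part of the AllReduce is not hidden behind backward. The 100 ms compute and 20 ms AllReduce figures and the function names are hypothetical. Efficiency is speedup / N, so 7.5× on 8 GPUs corresponds to 7.5 / 8 ≈ 94%.

```python
def step_time(compute_ms, comm_ms, overlap_fraction):
    """Per-step time on one GPU when overlap_fraction of the AllReduce is hidden
    behind backward compute (0 = fully exposed, 1 = fully hidden)."""
    hidden = min(comm_ms * overlap_fraction, compute_ms)
    return compute_ms + comm_ms - hidden


def scaling_efficiency(n_gpus, compute_ms, comm_ms, overlap_fraction):
    """Throughput speedup over a single GPU (which needs no AllReduce), divided by N."""
    speedup = n_gpus * compute_ms / step_time(compute_ms, comm_ms, overlap_fraction)
    return speedup / n_gpus


for overlap in (0.0, 0.5, 0.9):
    eff = scaling_efficiency(8, compute_ms=100.0, comm_ms=20.0, overlap_fraction=overlap)
    print(f"overlap={overlap:.1f}: {eff * 8:.2f}x speedup, {eff:.0%} efficiency")
```

With these made-up numbers, the same 20 ms AllReduce costs about 17% efficiency when fully exposed, but only a few percent once most of it overlaps with backward.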

Q: When would you use FSDP over DDP?

A: Use FSDP when the model's parameters, gradients and optimizer states don't fit on a single GPU. With DDP and mixed-precision Adam, each GPU needs 2 bytes (fp16 param) + 2 bytes (fp16 grad) + 12 bytes (fp32 master weight + two fp32 Adam moments) ≈ 16 bytes/param, before activations. A 1B parameter model needs ~16GB per GPU, which is feasible. A 10B model needs ~160GB per GPU, which is impossible even on an 80GB A100. FSDP shards all of these states across GPUs, so their per-GPU footprint is roughly 1/N (activations are not sharded). The tradeoff: FSDP communicates more than DDP, about 1.5× its volume. It runs an AllGather before each unit's forward and again in backward, plus a ReduceScatter of gradients. That cost is worth paying when the model would not fit otherwise.

Q: You scale DDP from 1 to 8 GPUs and the training loss curves don't match. Why?

A: Several causes: (1) Learning rate not scaled: with an 8× larger effective batch you typically need a higher LR, either ~2.83× (sqrt scaling, √8) or 8× (linear scaling) with warmup. (2) BatchNorm statistics: each GPU computes BN stats on its local data shard (batch/8 samples), so the stats are noisier. Fix: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) before wrapping in DDP, which synchronizes BN statistics across GPUs. (3) DistributedSampler epoch not set: without sampler.set_epoch(epoch), the shuffle seed never changes. Every epoch then repeats the same data order on each GPU, so the data is never reshuffled between epochs.
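A pure-Python sketch of the sharding logic DistributedSampler applies shows why set_epoch matters. It assumes random.Random in place of the torch generator the real sampler uses, so the exact permutations differ, and `shard_indices` is a hypothetical helper. The shuffle is seeded with seed + epoch (seed defaults to 0). The index list is padded by repeating indices until it divides evenly (drop_last=False), and rank r takes every world_size-th index starting at r. Without set_epoch the epoch stays 0, so every epoch reuses the same permutation.

```python
import math
import random


def shard_indices(dataset_len, world_size, rank, epoch, seed=0, shuffle=True):
    """Indices one rank visits in one epoch, mimicking DistributedSampler
    with drop_last=False (random.Random stands in for torch's generator)."""
    indices = list(range(dataset_len))
    if shuffle:
        random.Random(seed + epoch).shuffle(indices)  # new permutation per epoch
    total = math.ceil(dataset_len / world_size) * world_size
    padding = total - len(indices)
    # Pad by repeating indices so every rank gets the same number of samples
    indices += (indices * math.ceil(padding / len(indices)))[:padding]
    return indices[rank:total:world_size]  # rank r takes every world_size-th index


for epoch in range(2):
    # In the real sampler, set_epoch(epoch) is what changes this argument
    print(epoch, [shard_indices(10, 4, r, epoch) for r in range(4)])
```

With 10 samples on 4 ranks, each rank gets 3 indices and 2 samples are seen twice per epoch. The padding keeps every rank's step count identical, so no rank waits forever in an AllReduce.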