🛸 Hitchhiker's Guide — Phase 10: Distributed Training & Data Pipelines
Read this if: You can train a 100M model on one GPU and you want to know what changes at 70B on 1024 GPUs. This is where most engineers stop and most senior engineers start. Mastering this material is the single biggest differentiator at the senior+ level, because almost no one outside frontier labs gets hands-on practice — but everyone is asked about it in interviews.
0. The 30-second mental model
You can't train large models on one GPU because (a) the weights don't fit, (b) the optimizer state doesn't fit (AdamW keeps two FP32 tensors per parameter, about 4× the bytes of the BF16 weights), (c) the activations don't fit, and (d) one GPU can't push enough tokens per second to finish in your lifetime. Distributed training shards each of these across many GPUs while keeping the gradients mathematically identical to a single-GPU run (up to floating-point reduction order).
Six fundamental parallelism strategies; most production runs combine several:
| Strategy | What's sharded | Comm pattern | When to use |
|---|---|---|---|
| Data Parallel (DDP) | Nothing; full model replicated; each GPU sees different data | AllReduce of gradients per step | Small models that fit on one GPU |
| FSDP / ZeRO-3 | Weights, gradients, optimizer state | All-gather weights for forward; reduce-scatter grads | Models too big for one GPU but fit in sum-of-GPU-memory |
| Tensor Parallel (TP) | Each weight matrix split across GPUs in the same node | AllReduce per layer | Within a node (NVLink); MLP and attention matmuls |
| Pipeline Parallel (PP) | Different layers on different GPUs | Point-to-point per micro-batch | Across nodes when TP is saturated |
| Sequence / Context Parallel (SP/CP) | Sequence dimension split | Ring attention | Very long contexts (>32k) |
| Expert Parallel (EP) | MoE experts spread across GPUs | All-to-all per layer | MoE models |
Real 70B run example: TP=4 (within node) × PP=4 × DP=64 (FSDP) = 1024 GPUs. Each parallelism axis fixes a specific bottleneck.
By the end of Phase 10 you should:
- Pick the right parallelism strategy for any (model size, GPU count, interconnect) combo.
- Compute Model FLOPs Utilization (MFU) and explain why 40–50% is excellent.
- Implement DDP and FSDP from scratch (or near it) in PyTorch.
- Build the Phase 10 lab: a CommonCrawl → quality filter → MinHash-dedup → tokenize → mix data pipeline.
- Discuss MoE routing, expert parallelism, capacity factor.
- Be able to tell a believable war story about "we hit a NaN at step 28k and here's how we debugged it".
1. Why one GPU isn't enough — the memory math
For a 70B BF16 model:
- Weights: 70 × 2 = 140 GB. Doesn't fit on H100 80GB.
- Gradients (BF16): another 140 GB.
- AdamW state (FP32 m, v): 70 × 8 = 560 GB.
- Activations at batch=8, seq=4096: ~80 GB.
- Total: ~920 GB peak. ≈ 12 H100 80GB worth of memory just for one batch.
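The arithmetic above can be reproduced with a few lines. This is a minimal sketch that assumes the same byte counts as the list (BF16 weights and gradients, FP32 `m` and `v`, no separate FP32 master copy); the helper name `training_memory_gb` is made up for illustration.

```python
def training_memory_gb(params_billion, activations_gb=0.0):
    """Per-replica memory in GB under the byte counts assumed in the text."""
    weights = params_billion * 2      # BF16 weights: 2 bytes per parameter
    grads = params_billion * 2        # BF16 gradients: 2 bytes per parameter
    adam = params_billion * 8         # FP32 m and v: 4 + 4 bytes per parameter
    total = weights + grads + adam + activations_gb
    return {"weights": weights, "grads": grads, "adam": adam,
            "activations": activations_gb, "total": total}


mem = training_memory_gb(70, activations_gb=80)
gpus_needed = mem["total"] / 80       # 80 GB cards, ignoring fragmentation
print(mem["total"], gpus_needed)      # 920 GB, 11.5 cards
```

Adding an FP32 master copy of the weights (another 4 bytes per parameter) would raise the total further, which is why the optimizer-side terms are the first thing to shard.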
For training throughput, you also want hundreds to thousands of GPUs to finish in weeks, not centuries. Hence distributed.
2. Data Parallel (DDP) — the simplest
2.1 The setup
Every GPU has a full copy of the model. Each step:
- Each GPU samples a different micro-batch.
- Forward + backward locally → produces local gradients.
- AllReduce gradients across all GPUs (sum, then divide by world_size for averaging).
- Each GPU runs the same optimizer step → identical updated weights.
Mathematically equivalent to a single-GPU run with effective_batch = micro_batch × world_size.
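The equivalence claim is easy to check on a toy problem. The sketch below simulates two ranks in plain Python (no GPUs, no `torch.distributed`); `all_reduce_mean` is a stand-in for the real collective, and the one-parameter model is an assumption chosen so the numbers can be verified by hand.

```python
def grad_mse(w, batch):
    """d/dw of mean((w*x - y)^2) over a batch of (x, y) pairs."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)


def all_reduce_mean(values):
    """Stand-in for AllReduce(sum) followed by division by world_size."""
    return sum(values) / len(values)


w = 0.5
data = [(1.0, 2.0), (2.0, 3.0), (3.0, 7.0), (4.0, 9.0)]
world_size = 2
shards = [data[r::world_size] for r in range(world_size)]  # each rank sees different data

local_grads = [grad_mse(w, shard) for shard in shards]     # forward + backward per rank
ddp_grad = all_reduce_mean(local_grads)                    # gradient every rank ends up with
single_gpu_grad = grad_mse(w, data)                        # one GPU, full batch
print(ddp_grad, single_gpu_grad)
```

The match is exact only when every rank has the same micro-batch size; with uneven shards the average of per-rank means is no longer the full-batch mean.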
2.2 PyTorch API
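import os, torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])  # set by torchrun
model = DistributedDataParallel(model.cuda(local_rank), device_ids=[local_rank])
# train as normal: DDP overlaps the AllReduce with the backward pass automatically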
2.3 The bandwidth budget
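A ring AllReduce of B bytes across N GPUs moves about 2 (N-1)/N × B bytes per GPU. For a 7B BF16 model: 14 GB of gradients per step, so about 24.5 GB per GPU on 8 GPUs. On 8× H100 with NVLink (900 GB/s bidirectional, i.e. 450 GB/s per direction) the bandwidth-only floor is roughly 55 ms. Across nodes over one InfiniBand link (200–400 Gb/s, i.e. 25–50 GB/s) the same volume takes 0.5 s or more. Communication can dominate, so always overlap it with compute.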
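A small calculator makes the bandwidth budget concrete. It is a bandwidth-only lower bound (latency, protocol overhead, and topology are ignored), and the link speeds passed in are assumptions: the result changes by 2× depending on whether you plug in NVLink's bidirectional headline number or the per-direction figure.

```python
def ring_allreduce_seconds(payload_bytes, n_gpus, link_bytes_per_s):
    """Bandwidth-only lower bound: each GPU sends 2*(N-1)/N * B bytes."""
    per_gpu_bytes = 2 * (n_gpus - 1) / n_gpus * payload_bytes
    return per_gpu_bytes / link_bytes_per_s


grad_bytes = 7e9 * 2                                         # 7B parameters, BF16 gradients
t_headline = ring_allreduce_seconds(grad_bytes, 8, 900e9)    # bidirectional headline figure
t_one_way = ring_allreduce_seconds(grad_bytes, 8, 450e9)     # per-direction figure
t_ib = ring_allreduce_seconds(grad_bytes, 8, 50e9)           # one 400 Gb/s link (assumed)
print(f"{t_headline*1e3:.0f} ms, {t_one_way*1e3:.0f} ms, {t_ib*1e3:.0f} ms")
```

Real jobs usually land above these floors, which is the practical argument for overlapping the AllReduce with the backward pass.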
2.4 Limitations
DDP doesn't help with the memory problem. You replicate everything. Useless for models bigger than one GPU.
3. ZeRO and FSDP — sharding everything
3.1 ZeRO insight (Rajbhandari et al., 2020)
DDP redundantly stores 3 things across all N GPUs: optimizer state, gradients, weights. ZeRO shards them:
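- ZeRO-1: shard optimizer state. AdamW's FP32 m and v take 8 bytes per parameter, 4× the bytes of the BF16 weights, so this removes the largest term; the ZeRO paper reports up to ~4× lower model-state memory at this stage.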
- ZeRO-2: shard optimizer state + gradients.
- ZeRO-3: shard everything, including weights. (PyTorch's FSDP is functionally ZeRO-3.)
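To see what each stage buys, the sketch below computes per-GPU model-state memory for a 70B model on 64 data-parallel ranks. It assumes the byte counts used earlier in this guide (2 for weights, 2 for gradients, 8 for AdamW state); the function name is hypothetical.

```python
def zero_bytes_per_param(stage, n_gpus, w=2, g=2, opt=8):
    """Per-GPU model-state bytes per parameter for ZeRO stage 0 (plain DDP) to 3."""
    if stage >= 1:
        opt = opt / n_gpus      # ZeRO-1: optimizer state is sharded
    if stage >= 2:
        g = g / n_gpus          # ZeRO-2: gradients are sharded too
    if stage >= 3:
        w = w / n_gpus          # ZeRO-3: weights are sharded too
    return w + g + opt


# 70 billion parameters -> multiply bytes/param by 70 to get GB per GPU.
per_gpu_gb = {stage: 70 * zero_bytes_per_param(stage, n_gpus=64) for stage in range(4)}
for stage, gb in per_gpu_gb.items():
    print(f"ZeRO-{stage}: {gb:.1f} GB per GPU")
```

Only stage 3 brings a 70B model under 80 GB per GPU here, and that is before counting activations.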
3.2 FSDP forward/backward dance
For each layer's forward:
- All-gather the layer's weights from peers (so each GPU has full layer weights temporarily).
- Compute forward.
- Free the gathered weights (back to the local shard).
For backward:
- All-gather weights again.
- Compute backward.
- Reduce-scatter the gradients (each GPU keeps only its shard).
Memory: each GPU holds 1/N of weights + grads + opt state, plus the temporarily gathered full weights of the layer it is currently executing and the activations for its own micro-batch.
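The forward/backward dance can be simulated without any GPUs. In this toy sketch, plain Python lists play the role of tensors, `all_gather` and `reduce_scatter` are hand-written stand-ins for the NCCL collectives, and the "model" is a single 4-element weight vector split over 2 ranks; all of these are simplifying assumptions.

```python
def all_gather(shards):
    """Every rank receives the concatenation of all shards."""
    full = [x for shard in shards for x in shard]
    return [list(full) for _ in shards]


def reduce_scatter(per_rank_vectors):
    """Sum vectors elementwise across ranks; rank r keeps only slice r."""
    n = len(per_rank_vectors)
    summed = [sum(col) for col in zip(*per_rank_vectors)]
    size = len(summed) // n
    return [summed[r * size:(r + 1) * size] for r in range(n)]


def local_grad(w, x, y):
    """Gradient of (w.x - y)^2 with respect to the full weight vector w."""
    err = sum(wi * xi for wi, xi in zip(w, x)) - y
    return [2 * err * xi for xi in x]


world_size, lr = 2, 0.1
w_shards = [[1.0, 2.0], [3.0, 4.0]]                    # rank r owns w_shards[r]
batches = [([1.0, 0.0, 1.0, 0.0], 2.0), ([0.0, 1.0, 0.0, 1.0], 5.0)]

gathered = all_gather(w_shards)                        # full weights, held only temporarily
grads = [local_grad(gathered[r], *batches[r]) for r in range(world_size)]
grad_shards = reduce_scatter(grads)                    # each rank keeps its gradient slice
w_shards = [[wi - lr * gi / world_size for wi, gi in zip(ws, gs)]
            for ws, gs in zip(w_shards, grad_shards)]  # optimizer step on the local shard only
print(w_shards)
```

Each rank only ever updates its own slice, yet the concatenated result equals a single-process SGD step on the averaged gradient.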
3.3 PyTorch FSDP
import functools
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
# MyBlock is your transformer block class (an nn.Module); each instance becomes one FSDP unit
model = FSDP(
    model,
    auto_wrap_policy=functools.partial(transformer_auto_wrap_policy, transformer_layer_cls={MyBlock}),
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    mixed_precision=MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32),
)
Sharding strategies:
- FULL_SHARD (ZeRO-3): max memory savings, max comm.
- SHARD_GRAD_OP (ZeRO-2): less comm, more memory.
- HYBRID_SHARD: shard within a node, replicate across nodes. A big practical win: it uses fast NVLink for the high-bandwidth all-gather and slower IB only for cross-node gradient sync.
3.4 Activation checkpointing
Keep only the layer inputs during forward; recompute the layer's intermediates during backward. ~30% throughput hit, ~5× activation memory savings. Universal in big-model training.
4. Tensor Parallelism (Megatron-style)
4.1 The idea
Split each weight matrix across TP GPUs. Two flavors per matrix:
- Column-parallel (Y = X W): split W along output dim. Each GPU computes a slice of Y. No comm during forward; backward needs an AllReduce on grad-input.
- Row-parallel (Y = X W): split W along input dim. Each GPU computes partial Y. Forward AllReduce sums them.
For a transformer MLP: column-parallel up-proj (no comm), then row-parallel down-proj (one AllReduce). Symmetric for backward.
For attention: column-parallel QKV (no comm), per-head local attention (heads are independent → free), then row-parallel output projection (AllReduce).
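The MLP case can be checked numerically. This is a minimal pure-Python sketch with tiny hand-picked matrices and ReLU as the activation (both assumptions); the final `add` plays the role of the one forward AllReduce.

```python
def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def relu(m):
    return [[max(0.0, v) for v in row] for row in m]


def add(a, b):
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


X = [[1.0, -2.0], [0.5, 3.0]]                                # (batch, hidden)
W_up = [[1.0, -1.0, 2.0, 0.0], [0.0, 1.0, -1.0, 1.0]]        # (hidden, 4)
W_down = [[1.0, 0.0], [2.0, 1.0], [0.0, -1.0], [1.0, 1.0]]   # (4, hidden)

reference = matmul(relu(matmul(X, W_up)), W_down)            # unsharded MLP

tp = 2
partials = []
for rank in range(tp):
    cols = slice(rank * 2, (rank + 1) * 2)
    up_shard = [row[cols] for row in W_up]       # column-parallel: split the output dim
    down_shard = W_down[cols]                    # row-parallel: split the input dim
    hidden = relu(matmul(X, up_shard))           # elementwise activation needs no comm
    partials.append(matmul(hidden, down_shard))  # partial sum of the final output

tp_output = add(partials[0], partials[1])        # the single forward AllReduce
print(tp_output == reference)
```

The trick works because the activation is elementwise: each rank can apply it to its own column slice without seeing the others.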
4.2 The cost
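Two AllReduces per layer in the forward pass (one in attention, one in MLP), and two more in the backward pass. Each one moves a full activation tensor (batch × seq × hidden × 2 bytes in BF16, e.g. 8 × 4096 × 4096 × 2 B ≈ 268 MB for a 7B-class model) and sits on the critical path of every layer, so TP requires NVLink-class interconnect. In practice TP is capped at the number of GPUs in one node (8 on H100 servers); beyond that, IB bandwidth crushes throughput.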
4.3 When to use it
- Models too big for FSDP alone (very large activations during forward).
- Helps reduce per-GPU activation memory because each GPU computes only 1/TP of each matmul.
- Combine with PP (across nodes) and FSDP (data dim).
5. Pipeline Parallelism
5.1 The setup
Layers 1–L/4 on GPU group 0; L/4+1 to L/2 on group 1; etc. Forward passes through groups; backward in reverse.
5.2 The bubble problem
Naive PP: GPU 1 sits idle while GPU 0 computes the first batch. Then GPU 1 works while GPU 0 idles. Etc. With P pipeline stages, only 1/P of GPUs are working at any moment — terrible utilization.
5.3 Mitigations
- Micro-batching (1F1B schedule): split each macro batch into M micro-batches. Pipeline them. Bubble time = (P-1) micro-batches. Bubble fraction = (P - 1) / M. Need M ≫ P (e.g., M=64 for P=4).
- Interleaved pipeline (Megatron-LM): assign multiple non-contiguous layer chunks per stage. Smaller bubbles.
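The bubble formula is worth computing once by hand. The sketch below gives both the ratio used above, idle time relative to ideal compute time, and the related share of the whole step that is idle; the function names are illustrative.

```python
def bubble_fraction(p_stages, m_microbatches):
    """Idle time relative to ideal compute time: (P - 1) / M."""
    return (p_stages - 1) / m_microbatches


def idle_share_of_step(p_stages, m_microbatches):
    """Idle time as a share of the whole step: (P - 1) / (M + P - 1)."""
    return (p_stages - 1) / (m_microbatches + p_stages - 1)


for m in (4, 16, 64):
    print(m, bubble_fraction(4, m), round(idle_share_of_step(4, m), 3))
```

With P=4, going from M=16 to M=64 drops the bubble from about 19% to under 5%, which is why gradient accumulation and pipeline depth are tuned together.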
5.4 When to use it
- Across nodes (slow IB): point-to-point messages between adjacent stages are smaller than TP's AllReduce.
- Combine: TP within node, PP across nodes, DP/FSDP wrapping it all.
6. Sequence / Context Parallelism
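For very long contexts (32k+), the sequence dimension is the bottleneck: attention compute grows as O(T²), and per-GPU activation memory grows with T (quadratically if the attention matrix is materialized). Split the sequence across GPUs.
Ring Attention (Liu et al., 2023): each GPU holds 1/N of the sequence's Q, K, and V; the K and V blocks are passed around a ring while each GPU accumulates attention for its own queries, overlapping communication with compute. Reportedly used in production long-context training, although labs rarely publish the details.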
7. Expert Parallelism (for MoE)
7.1 MoE quick recap
Mixture of Experts (Shazeer 2017, Switch Transformer Fedus 2021): replace each MLP with E parallel "expert" MLPs and a small router that picks the top-k experts per token (typically k=2). Sparse activation: each token uses only k/E of the expert parameters, so compute per token stays roughly constant as E grows.
Models: GPT-4 (rumored), Mixtral 8×7B (8 experts, top-2), DeepSeek-V3 (256 experts + 1 shared), Qwen-MoE.
7.2 Expert parallelism
Place different experts on different GPUs. Per-layer flow:
- Router decides which expert each token goes to.
- All-to-all: send each token's hidden state to its expert's GPU.
- Each expert runs its MLP locally.
- All-to-all: send results back.
All-to-all is bandwidth-intensive. Capacity factor (typically 1.25): allow each expert to receive up to 1.25 × tokens / E to handle imbalance — overflow is dropped or sent to a backup expert.
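Capacity-limited dispatch fits in a few lines. This sketch assumes top-1 routing, first-come-first-served slot assignment, and dropping on overflow; real implementations differ in tie-breaking and in what happens to dropped tokens, and `route_with_capacity` is a hypothetical helper.

```python
import math


def route_with_capacity(expert_choice, n_experts, capacity_factor=1.25):
    """Assign tokens to their chosen expert (top-1) until that expert is full."""
    capacity = math.ceil(capacity_factor * len(expert_choice) / n_experts)
    load = [0] * n_experts
    kept, dropped = [], []
    for token, expert in enumerate(expert_choice):
        if load[expert] < capacity:
            load[expert] += 1
            kept.append((token, expert))
        else:
            dropped.append(token)      # overflow: this token skips the expert MLP
    return capacity, kept, dropped


choices = [0, 0, 0, 0, 1, 2, 3, 0]     # router output for 8 tokens, 4 experts
capacity, kept, dropped = route_with_capacity(choices, n_experts=4)
print(capacity, dropped)               # capacity 3; tokens 3 and 7 overflow expert 0
```

A fixed capacity is what keeps the all-to-all buffers a static shape, at the price of dropped tokens when the router is unbalanced.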
7.3 MoE routing problems
- Load balancing: some experts get all the work. Use auxiliary loss penalizing imbalance.
- Token dropping: capacity overflow loses some tokens' contribution. Tune capacity factor.
- Routing instability: a token's expert assignment can flip between steps during training; mitigated by router z-loss or noise.
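One common form of the load-balancing term is the Switch Transformer auxiliary loss, E · Σ f_i · P_i, where f_i is the fraction of tokens routed to expert i and P_i is the mean router probability for expert i. The sketch below assumes top-1 routing and omits the scaling coefficient that multiplies it into the total loss.

```python
def load_balance_loss(router_probs):
    """Switch-style auxiliary loss: E * sum_i f_i * P_i (1.0 when perfectly balanced)."""
    n_tokens, n_experts = len(router_probs), len(router_probs[0])
    top1 = [max(range(n_experts), key=p.__getitem__) for p in router_probs]
    f = [top1.count(e) / n_tokens for e in range(n_experts)]           # dispatch fractions
    P = [sum(p[e] for p in router_probs) / n_tokens for e in range(n_experts)]
    return n_experts * sum(fi * pi for fi, pi in zip(f, P))


balanced = [[0.9, 0.1], [0.1, 0.9]]    # each expert gets one token
collapsed = [[0.9, 0.1], [0.8, 0.2]]   # expert 0 gets everything
print(load_balance_loss(balanced), load_balance_loss(collapsed))
```

The loss rises as routing collapses onto one expert, so minimizing it pushes the router back toward a uniform split.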
8. Putting it together — a real recipe
8.1 70B on 1024 H100s
- TP = 4: within each 8-GPU H100 node, shard each transformer layer 4 ways, so each node hosts two TP groups. (On 8-GPU nodes, TP=8 is also common if the model is wide enough.)
- PP = 4: split the 80 layers into 4 stages (20 layers each), one per node group.
- DP = 64 (with FSDP HYBRID_SHARD): 1024 / (4 × 4) = 64 data-parallel replicas.
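- Effective batch: micro × DP × grad_accum = e.g., 1 × 64 × 32 = 2048 sequences × 4096 tokens ≈ 8.4M tokens per step.
- Steps for 1.4T tokens: 1.4e12 / 8.4e6 ≈ 167k steps.
- Wall clock at 50% MFU on 1024 H100s: 6 × 70e9 × 1.4e12 ≈ 5.9e23 FLOPs; at ~989 TFLOP/s dense BF16 peak per GPU that is roughly 13–14 days (several times longer on A100-class hardware or at lower MFU).
- Cost at $2/H100-hour: ~330k GPU-hours, i.e. roughly $0.7M for the final run alone.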
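The recipe's arithmetic can be reproduced directly. The one number not taken from the list is the per-GPU peak: the sketch assumes 989 TFLOP/s of dense BF16 throughput per H100, and the result scales inversely with whatever peak and MFU you plug in.

```python
tokens_per_step = 1 * 64 * 32 * 4096           # micro x DP x grad_accum x seq_len
steps = 1.4e12 / tokens_per_step               # optimizer steps for 1.4T tokens

train_flops = 6 * 70e9 * 1.4e12                # 6 * params * tokens
peak_per_gpu = 989e12                          # ASSUMED dense BF16 peak per H100, FLOP/s
seconds = train_flops / (1024 * peak_per_gpu * 0.50)   # 1024 GPUs at 50% MFU
days = seconds / 86400
gpu_hours = 1024 * seconds / 3600
cost_usd = gpu_hours * 2.0                     # $2 per GPU-hour

print(f"{tokens_per_step:,} tokens/step, {steps:,.0f} steps, "
      f"{days:.1f} days, ${cost_usd:,.0f}")
```

Restarts, evaluation, and failed runs are not included, so real budgets come out noticeably higher than this floor.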
8.2 Model FLOPs Utilization (MFU)
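$$ \text{MFU} = \frac{\text{achieved FLOPs/s}}{\text{peak FLOPs/s}} = \frac{6 N D / T}{N_{\text{GPU}} \cdot \text{peak per GPU}} $$
Here N is the parameter count, D the number of tokens processed, T the wall-clock time in seconds, and N_GPU the number of GPUs.
- 30% MFU: typical for an untuned config.
- ~40% MFU: good; Llama 3 reported 38–43% on H100.
- 50%+: excellent.
- Anthropic / OpenAI rumored 55%+ on internal stacks (unverified).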
If your MFU is 15%, you have a bug or a misconfig — investigate.
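In practice MFU is computed from measured throughput. A minimal sketch, assuming the 6·N FLOPs-per-token approximation and a 989 TFLOP/s dense BF16 peak per H100 (an assumed figure); the throughput value is made up for illustration.

```python
def mfu(n_params, tokens_per_second, n_gpus, peak_flops_per_gpu):
    """Model FLOPs Utilization: achieved training FLOP/s over aggregate peak FLOP/s."""
    achieved = 6 * n_params * tokens_per_second     # forward + backward, dense model
    return achieved / (n_gpus * peak_flops_per_gpu)


# Hypothetical measurement: a 70B model processing 1M tokens/s on 1024 GPUs.
example = mfu(70e9, 1.0e6, 1024, 989e12)
print(f"MFU = {example:.1%}")
```

The 6·N rule ignores attention FLOPs, so at long sequence lengths this slightly understates the work actually done.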
9. The data pipeline — Phase 10's lab focus
9.1 The pipeline (9 stages)
- Source: CommonCrawl WARC files, GitHub crawls, books, papers.
- Parse: WARC → text (HTML extraction with trafilatura or readability).
- URL dedup: drop pages already seen.
- Language ID: fasttext lid.176. Keep target languages.
- Quality filter: Gopher rules (Rae et al., 2021): symbol-to-word ratio, line length distribution, stopword density, repeating n-grams.
- PII scrub: emails, phones, credit card patterns.
- Near-dup: MinHash + LSH (datasketch) at Jaccard ~0.8.
- Toxicity / NSFW filter: classifier (e.g., hate-speech model).
- Tokenize and shard: write uint16/uint32 .bin files, ~1–10GB each.
Then mix: Common Crawl 70%, code 10%, books 5%, papers 5%, Wikipedia 5%, etc. Tune mixing weights with DSIR (Xie et al., 2023) or DoReMi (Xie et al., 2023), or hand-tune via small-scale ablations.
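The quality-filter stage in the pipeline above is mostly cheap string heuristics. The sketch below is in the spirit of the Gopher rules, but every threshold and the stopword list are illustrative assumptions rather than a faithful copy of the paper; tune them on your own data.

```python
STOPWORDS = {"the", "be", "to", "of", "and", "that", "have", "with"}


def passes_quality(text, min_words=50, max_words=100_000):
    """Gopher-style heuristics; all thresholds here are illustrative assumptions."""
    words = text.split()
    if not (min_words <= len(words) <= max_words):
        return False                                   # too short or too long
    mean_len = sum(len(w) for w in words) / len(words)
    if not (3 <= mean_len <= 10):
        return False                                   # gibberish or token soup
    symbols = text.count("#") + text.count("...")
    if symbols / len(words) > 0.1:
        return False                                   # high symbol-to-word ratio
    if sum(w.lower() in STOPWORDS for w in words) < 2:
        return False                                   # no natural-language stopwords
    lines = [ln.strip() for ln in text.splitlines() if ln.strip()]
    if lines and len(set(lines)) / len(lines) < 0.7:
        return False                                   # mostly repeated lines
    return True


good = "The quick brown fox jumps over the lazy dog and runs to the river with joy. " * 10
spam = "buy now ### " * 100
print(passes_quality(good), passes_quality(spam), passes_quality("too short"))
```

Logging which rule rejected each document is worth the extra few lines, because it is how you spot a threshold that is silently discarding good data.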
9.2 Lineage tracking
Every doc carries a chain of pre_filter_hash → post_filter_hash → tokenized_shard_id. When you discover a problem (a leaked benchmark, content tied to a CVE), you can trace the affected documents to their shards and purge them.
9.3 Lab walkthrough (lab-01-data-pipeline)
What you'll build:
- parse_wet(path): yields documents from a CommonCrawl WET file using warcio.
- is_english(text): fasttext lid.176 model.
- passes_quality(text): implements Gopher rules: word count thresholds, average word length, symbol ratio, line uniqueness, etc.
- Deduper: datasketch.MinHashLSH with threshold 0.8, num_perm=128.
- tokenize_to_bin(docs, out_path): uses tiktoken GPT-2; writes uint16 little-endian; appends EOT token between docs.
Run it on a few dozen MB of WET data; observe filter ratios (typical: 20–40% retained after all filters). Observe how the Gopher rules catch SEO spam, low-content boilerplate, etc.
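Before reaching for `datasketch`, it helps to see what a MinHash signature is. This is a from-scratch sketch using only the standard library: salted BLAKE2b hashes stand in for random permutations, shingles are word 3-grams, and the LSH banding step is omitted. It is an illustration of the idea, not the lab's `Deduper`.

```python
import hashlib


def shingles(text, n=3):
    """Set of word n-grams; the unit over which Jaccard similarity is measured."""
    words = text.lower().split()
    return {" ".join(words[i:i + n]) for i in range(max(1, len(words) - n + 1))}


def minhash(items, num_perm=128):
    """One min-hash per seed; a salted hash stands in for a random permutation."""
    signature = []
    for seed in range(num_perm):
        salt = seed.to_bytes(4, "little")
        signature.append(min(
            int.from_bytes(
                hashlib.blake2b(s.encode(), digest_size=8, salt=salt).digest(), "little")
            for s in items))
    return signature


def estimated_jaccard(sig_a, sig_b):
    """Fraction of matching signature slots estimates the Jaccard similarity."""
    return sum(a == b for a, b in zip(sig_a, sig_b)) / len(sig_a)


doc = "the cat sat on the mat and looked at the dog in the yard today"
near_dup = "the cat sat on the mat and looked at the dog in the yard yesterday"
other = "distributed training shards weights gradients and optimizer state across many gpus"

sig = minhash(shingles(doc))
print(estimated_jaccard(sig, minhash(shingles(near_dup))),
      estimated_jaccard(sig, minhash(shingles(other))))
```

LSH then groups signatures into bands so that only documents sharing a band are compared, which is what makes dedup feasible at billions of documents.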
10. Debugging at scale — the war stories
10.1 Loss spike at step 28k
Symptoms: BF16 training, loss suddenly 10× higher for one step. Common causes:
- Bad batch (e.g., a single very-long doc with garbage).
- Numerical overflow or precision loss in the attention softmax (very large logits).
- Bug in attention masking.
Standard response: skip the batch and continue; if recurring, lower LR or add gradient clipping.
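"Skip the batch" needs a trigger. One simple option, sketched below, compares each loss with a rolling median of recent healthy steps; the class name, window, and 3× ratio are all assumptions to tune per run.

```python
import math
from collections import deque
from statistics import median


class SpikeGuard:
    """Flags a step whose loss is NaN or exceeds `ratio` x the recent median loss."""

    def __init__(self, window=50, ratio=3.0, min_history=5):
        self.history = deque(maxlen=window)
        self.ratio = ratio
        self.min_history = min_history

    def should_skip(self, loss):
        spike = math.isnan(loss) or (
            len(self.history) >= self.min_history
            and loss > self.ratio * median(self.history))
        if not spike:
            self.history.append(loss)   # only healthy steps update the baseline
        return spike


guard = SpikeGuard()
losses = [2.0] * 10 + [20.0, 2.1, float("nan")]
flags = [guard.should_skip(x) for x in losses]
print(flags)
```

In a training loop, a flagged step would zero the gradients and skip the optimizer update; logging the batch's document IDs at that point is what lets you find the offending data later.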
10.2 NaN
- Often FP16 overflow (inf, then NaN), or gradient underflow without loss scaling → switch to BF16.
- Or division by zero somewhere (norm of zero vector).
- Or a corrupted checkpoint reload.
10.3 NCCL hang
- One GPU fails or becomes slow → AllReduce times out → entire job hangs.
- NCCL watchdog (env TORCH_NCCL_BLOCKING_WAIT=1 and timeout) detects and aborts.
- Health check + restart from latest checkpoint.
10.4 Async checkpointing
Synchronous checkpointing every 1k steps stalls training for ~5 minutes. Async: snapshot weights into pinned-host memory in one fast op, then a background process writes to storage. PyTorch DCP (Distributed Checkpoint) supports this.
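The snapshot-then-write pattern can be shown with the standard library alone. In this sketch `copy.deepcopy` stands in for the fast copy to pinned host memory and `pickle` for the real serialization format; it illustrates the control flow and is not how PyTorch DCP is implemented.

```python
import copy
import os
import pickle
import tempfile
import threading


def async_checkpoint(state, path):
    """Snapshot synchronously (fast), then write in a background thread (slow)."""
    snapshot = copy.deepcopy(state)        # stand-in for the copy to pinned host memory

    def _write():
        tmp = path + ".tmp"
        with open(tmp, "wb") as f:
            pickle.dump(snapshot, f)
        os.replace(tmp, path)              # atomic rename: no half-written checkpoints

    writer = threading.Thread(target=_write)
    writer.start()
    return writer


state = {"step": 1000, "weights": [0.1, 0.2, 0.3]}
path = os.path.join(tempfile.mkdtemp(), "ckpt.pkl")
writer = async_checkpoint(state, path)
state["step"] = 1001                       # training continues while the write runs
writer.join()                              # in a real loop, join before the next snapshot
with open(path, "rb") as f:
    restored = pickle.load(f)
print(restored["step"])
```

The two details that matter carry over to the real thing: the snapshot must be taken before the next optimizer step mutates the weights, and the final file must appear atomically.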
10.5 The right defaults
- torch.compile(model): almost always a free 10–30% speedup.
- BF16 throughout; FP32 reductions and master weights only.
- Gradient clipping at 1.0.
- Activation checkpointing on every transformer layer.
- AdamW(0.9, 0.95), wd=0.1.
- LR warmup over first 2000 steps; cosine to 10% of peak.
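The schedule in the last bullet, written out. A minimal sketch assuming linear warmup and a cosine decay that bottoms out at 10% of peak exactly at `total_steps`; the peak learning rate and step count in the example are placeholders.

```python
import math


def lr_at(step, peak_lr, total_steps, warmup_steps=2000, min_ratio=0.1):
    """Linear warmup to peak_lr, then cosine decay to min_ratio * peak_lr."""
    if step < warmup_steps:
        return peak_lr * (step + 1) / warmup_steps
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))    # 1 -> 0 over the decay
    return peak_lr * (min_ratio + (1.0 - min_ratio) * cosine)


for s in (0, 1999, 51000, 100000):
    print(s, lr_at(s, peak_lr=3e-4, total_steps=100000))
```

Steps past `total_steps` stay at the floor, which is convenient if a run is extended after the schedule has finished.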
11. References
Required:
- Rajbhandari et al. (2020), ZeRO: Memory Optimizations Toward Training Trillion Parameter Models.
- Rajbhandari et al. (2021), ZeRO-Infinity.
- Shoeybi et al. (2019), Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism.
- Narayanan et al. (2021), Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM.
- Smith et al. (2022), Using DeepSpeed and Megatron to Train Megatron-Turing NLG 530B.
- The PyTorch FSDP tutorial and paper (Zhao et al., 2023).
- Rae et al. (2021), Scaling Language Models: Methods, Analysis & Insights from Training Gopher — appendix has the quality filter rules.
- Penedo et al. (2023), The RefinedWeb Dataset for Falcon LLM.
- Together's RedPajama data card.
- Xie et al. (2023), DoReMi: Optimizing Data Mixtures Speeds Up Language Model Pretraining.
Important:
- Liu et al. (2023), Ring Attention with Blockwise Transformers for Near-Infinite Context.
- Fedus et al. (2021), Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity.
- Lepikhin et al. (2020), GShard: Scaling Giant Models with Conditional Computation and Automatic Sharding.
- Llama-3 tech report.
- DeepSeek-V3 tech report.
- The OPT logbook (Zhang et al. 2022, appendix).
12. Common interview questions on Phase 10 material
- Walk through DDP, FSDP, TP, PP, EP. Pick the right combo for 70B on 1024 H100s.
- Why is TP usually capped at one node?
- Compute the bubble fraction for PP=4, M=16 micro-batches.
- What's MFU and what's a good number?
- Sketch FSDP's forward and backward.
- Why does ZeRO-3 (= FSDP) cut per-GPU memory for weights, gradients, and optimizer state by roughly N× versus DDP?
- What's all-to-all and why is MoE routing expensive?
- Compute the AllReduce cost for a 7B BF16 model across 8 GPUs.
- Loss spikes at step 28k — what do you do?
- Walk through a CommonCrawl → tokens pipeline.
- What's MinHash LSH and how is it used for dedup?
- Compare DoReMi and DSIR for data mix optimization.
- How would you implement async checkpointing?
- Your MFU is 18%. What are the top 5 things to check?
- Llama-3 was trained on 15T tokens at 8B params — that's 1900 tokens/param. Why so far past Chinchilla?
13. From solid → exceptional
- Implement DDP from scratch using torch.distributed.all_reduce. Train a 100M model on 2 GPUs; verify that the gradients match a single-GPU run.
- Run a real FSDP experiment on 4× consumer GPUs with a 7B model. Measure memory and throughput against a DDP attempt.
- Implement MinHash LSH (or use datasketch); dedup a 10GB text corpus; report compression ratio.
- Build the Phase 10 lab data pipeline; measure each stage's filter ratio.
- Read the Llama-3 tech report end-to-end; write a one-page summary of every distributed-training decision.
- Read the DeepSeek-V3 tech report; understand its mixture of FP8 + DualPipe + auxiliary-loss-free routing.
- Implement a tiny MoE block with top-2 routing, capacity factor 1.25, load-balancing aux loss.
- Profile a real distributed run with torch.profiler + Nsight; identify where comm overlaps (or doesn't) with compute.
14. Recommended cadence
| Day | Activity |
|---|---|
| Mon | Read ZeRO + Megatron papers |
| Tue | Read FSDP paper + PyTorch tutorial |
| Wed | Lab 01 — build the data pipeline; run on a small WET file |
| Thu | Read RefinedWeb + Gopher data sections; refine quality rules |
| Fri | Implement DDP from scratch on 2 GPUs (or via Colab+Kaggle) |
| Sat | Read Llama-3 tech report; sketch the parallelism layout |
| Sun | Mock interview the 15 questions; whiteboard the parallelism table |