02 — Distributed Pretraining (8B → 70B)
Roles: Research Engineer, Pretraining (Anthropic, OpenAI, DeepMind, Meta, xAI)
1. Clarifying Questions
- Target model size and token budget? (Chinchilla compute-optimal is ~20 tok/param, so 8B → ~160B tokens and 70B → ~1.4T; runs that prioritize cheap inference often train well past this.)
- Hardware: H100 / H200 / TPUv5p? How many nodes? Interconnect (NVLink + InfiniBand / TPU ICI)?
- Training duration target? (Days? Weeks?)
- Checkpointing / restart frequency?
- Mixed-precision (BF16 + FP8)?
- Architecture: dense vs MoE?
2. Capacity Estimation
Example: 70B dense model, 1.5T tokens, BF16 + FSDP.
- Params: 70B × 2 bytes = 140 GB (weights)
- Gradients (BF16): 70B × 2 bytes = 140 GB
- Optimizer state (AdamW: FP32 master weights + FP32 m + FP32 v): 12 bytes/param = 840 GB
- Activations (with recompute): scales with batch × seq × layers
- Total memory per "model replica": > 1 TB → MUST be sharded (FSDP/ZeRO-3 or TP)
- Compute: 6 × P × T flops ≈ 6 × 70e9 × 1.5e12 = 6.3e23 flops
- On H100 SXM (≈989 TFLOPS peak dense BF16) at 45% MFU ≈ 445 TFLOPS effective: 6.3e23 / 4.45e14 ≈ 1.4e9 GPU-seconds ≈ 393k GPU-hours
- → 1024 H100s for ~16 days, or 4096 H100s for ~4 days (pure compute; restarts and evals add wall-clock time)
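The sizing arithmetic can be scripted so every assumption is explicit. A minimal sketch, assuming the 6·P·T FLOPs rule, an H100 SXM dense BF16 peak of ~989 TFLOPS with MFU measured against that peak, and a flat $4/GPU-hour; `estimate_run` and `chinchilla_tokens` are hypothetical helpers:

```python
# Back-of-envelope sizing for dense pretraining (all inputs are assumptions to edit).
SECONDS_PER_DAY = 86_400


def chinchilla_tokens(params: float, tokens_per_param: float = 20.0) -> float:
    """Compute-optimal token budget under the ~20 tokens/param rule of thumb."""
    return params * tokens_per_param


def estimate_run(params: float, tokens: float, gpus: int,
                 peak_flops: float = 989e12,   # assumed H100 SXM dense BF16 peak
                 mfu: float = 0.45,            # fraction of peak actually achieved
                 usd_per_gpu_hour: float = 4.0) -> dict:
    state_gb = {
        "weights_bf16": params * 2 / 1e9,
        "grads_bf16": params * 2 / 1e9,
        "adam_fp32_master_m_v": params * 12 / 1e9,
    }
    train_flops = 6 * params * tokens               # ~6 FLOPs per param per token (fwd + bwd)
    gpu_seconds = train_flops / (peak_flops * mfu)  # MFU is relative to peak, applied once
    return {
        "state_gb": state_gb,
        "total_state_gb": sum(state_gb.values()),   # excludes activations
        "train_flops": train_flops,
        "gpu_hours": gpu_seconds / 3600,
        "days": gpu_seconds / gpus / SECONDS_PER_DAY,
        "usd": gpu_seconds / 3600 * usd_per_gpu_hour,
    }


if __name__ == "__main__":
    run = estimate_run(params=70e9, tokens=1.5e12, gpus=1024)
    print(f"state ≈ {run['total_state_gb']:.0f} GB, {run['train_flops']:.2e} FLOPs, "
          f"{run['days']:.1f} days, ${run['usd'] / 1e6:.2f}M")
```

For the 70B / 1.5T-token / 1024-GPU example this prints roughly 1120 GB of weight, gradient, and optimizer state, 6.30e+23 FLOPs, 16.0 days, and $1.57M, before activations, restarts, or evals.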
3. Parallelism Plan
| Dim | Strategy | Why |
|---|---|---|
| Data | DDP / FSDP across replicas | Throughput |
| Tensor (TP) | Megatron-style, within node (TP=4 or 8 over NVLink) | Reduce per-GPU memory; avoid cross-node TP (latency!) |
| Pipeline (PP) | 1F1B or interleaved schedules across nodes | Fit 70B+ across nodes |
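| Sequence/Context (SP/CP) | SP: Megatron-style, shards LayerNorm/dropout activations along the sequence within the TP group; CP: ring attention across GPUs | Activation memory; long context (128k+) |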
| Expert (EP) | Top-2 routing, capacity factor 1.25 | If MoE |
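To make the capacity-factor knob concrete: each expert gets a fixed number of slots per batch, and token-expert assignments beyond that are dropped. A minimal sketch assuming a common formulation, `ceil(capacity_factor × tokens × top_k / num_experts)`, and a simple first-come dispatch order; both helper names are hypothetical:

```python
import math
from collections import Counter


def expert_capacity(tokens: int, num_experts: int, top_k: int, capacity_factor: float) -> int:
    """Slots per expert: capacity_factor times the perfectly balanced load, rounded up."""
    return math.ceil(capacity_factor * tokens * top_k / num_experts)


def dispatch(assignments: list[list[int]], capacity: int) -> tuple[list[tuple[int, int]], int]:
    """Greedy in-order dispatch; pairs that arrive after an expert is full are dropped."""
    load: Counter = Counter()
    kept: list[tuple[int, int]] = []
    dropped = 0
    for token_id, experts in enumerate(assignments):
        for e in experts:
            if load[e] < capacity:
                load[e] += 1
                kept.append((token_id, e))
            else:
                # Dropped pair: this token gets no output from expert e
                # (its other expert and the residual path still apply).
                dropped += 1
    return kept, dropped
```

With 4096 tokens, 8 experts, top-2, and a 1.25 capacity factor, each expert gets 1280 slots, i.e. 25% headroom over a perfectly balanced 1024.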
Composition example (70B dense, 1024 H100, 8/node):
- TP = 4 (within node)
- PP = 4 (across nodes — partitions of layers)
- DP = 64 (replicas) → 4 × 4 × 64 = 1024
- FSDP (ZeRO-style) shards optimizer state (and optionally grads and params) across the 64 DP ranks
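A sketch of how the TP=4 / PP=4 / DP=64 layout maps ranks onto GPUs and what it does to per-GPU state memory. It assumes ranks are numbered node-major (`rank // 8` is the node) and a TP-fastest, then DP, then PP ordering (Megatron-like, but an assumption here); `Layout` and `per_gpu_state_gb` are hypothetical names:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Layout:
    tp: int
    pp: int
    dp: int
    gpus_per_node: int = 8

    @property
    def world_size(self) -> int:
        return self.tp * self.pp * self.dp

    def coords(self, rank: int) -> tuple[int, int, int]:
        """Global rank -> (dp, pp, tp). TP varies fastest, then DP, then PP (assumed order)."""
        tp = rank % self.tp
        dp = (rank // self.tp) % self.dp
        pp = rank // (self.tp * self.dp)
        return dp, pp, tp

    def tp_group_on_one_node(self, rank: int) -> bool:
        """True if all ranks in this rank's TP group share a node (NVLink-only TP traffic)."""
        first = rank - rank % self.tp
        last = first + self.tp - 1
        return first // self.gpus_per_node == last // self.gpus_per_node


def per_gpu_state_gb(params: float, layout: Layout, shard_optimizer: bool = True) -> float:
    """BF16 weights + grads (4 B/param) and Adam state (12 B/param) per GPU; no activations."""
    local = params / (layout.tp * layout.pp)   # TP and PP split the model itself
    bf16_weights_and_grads = local * 4
    adam = local * 12 / (layout.dp if shard_optimizer else 1)
    return (bf16_weights_and_grads + adam) / 1e9
```

Under these assumptions every TP group stays on one node, and sharding Adam state across the 64 DP ranks takes per-GPU state from ~70 GB to ~18 GB, leaving room for activations on an 80 GB H100.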
4. Architecture
Coordinator/Scheduler (Slurm / k8s + Volcano)
│
▼
[Job: 1024 H100s (128 nodes × 8 GPUs), 16 racks, fat-tree IB]
│
├── Rank-0 driver: checkpoint manifest, eval scheduling, logging
├── Data loader workers (per node): stream from object store
├── Tokenized shards (uint16 .bin if vocab ≤ 65,536, else uint32) on local NVMe (warmed from S3)
├── Async checkpointing → S3 (fully sharded, every N steps)
└── Telemetry: every step, every rank → Prometheus / W&B / ClearML
5. Deep Dives
5.1 Numerical Stability
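- FP32 master weights and AdamW moments; BF16 forward/backward; reduce gradients in FP32 (a BF16 reduce halves bandwidth at some precision cost)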
- FP8 matmuls (E4M3 forward, E5M2 gradients) with per-tensor scaling on Hopper Tensor Cores; keep unstable layers (often the LM head) in BF16
- Loss scaling not needed in BF16
- Gradient clipping at 1.0
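- Residual-stream variance grows with depth: scale residual-projection init (e.g., by 1/√(2·n_layers)), consider μP for hyperparameter transfer at extreme scale, and use QK-norm to bound attention logits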
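Why FP32 master weights matter: BF16 keeps only 8 significant bits, so near 1.0 its spacing is 2^-7 ≈ 0.0078 and a small update added directly to a BF16 weight rounds away. A stdlib-only sketch that emulates BF16 by rounding float32 bit patterns (round-to-nearest-even, NaN/Inf not handled); the 1e-4 update size is an arbitrary illustration:

```python
import struct


def to_fp32(x: float) -> float:
    """Round a Python float (double) to the nearest float32."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_bf16(x: float) -> float:
    """Round to BF16 (top 16 bits of float32), round-to-nearest-even; NaN/Inf not handled."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)   # round-to-nearest-even on the 16 dropped bits
    bits = (bits >> 16) << 16             # truncate to the BF16 bit pattern
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFFFFFF))[0]


def apply_updates(steps: int = 1000, update: float = 1e-4) -> tuple[float, float, float]:
    w_bf16 = to_bf16(1.0)     # weight stored and updated directly in BF16
    w_master = to_fp32(1.0)   # FP32 master copy; BF16 compute weights are cast from it
    for _ in range(steps):
        w_bf16 = to_bf16(w_bf16 + update)
        w_master = to_fp32(w_master + update)
    return w_bf16, w_master, to_bf16(w_master)
```

Applied directly in BF16, 1,000 updates of 1e-4 leave the weight at exactly 1.0; the FP32 master copy reaches ~1.1, and its BF16 cast (1.1015625) is what the next forward pass would see.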
5.2 Data Loading at Scale
- Shards on S3 (10s of TB tokenized)
- Stripe across NVMe on each node; double-buffer; prefetch 2 batches ahead
- Deterministic interleaving: shard = hash(epoch, step, rank), so the data order needs no saved sampler state
- Resumable: on restart, jump to (epoch, step), each rank deterministically reproduces the same batches
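One way to realize the hash(epoch, step, rank) → shard mapping so that each shard is read exactly once per epoch and any rank can recompute its position after a restart: derive a seed from (seed, epoch) with a stable hash, shuffle shard IDs once per epoch, and index by `step × world_size + rank`. A minimal sketch; the function names are hypothetical, and `random.Random` shuffles are only guaranteed reproducible on the same Python version:

```python
import hashlib
import random
from functools import lru_cache


@lru_cache(maxsize=4)
def epoch_order(seed: int, epoch: int, num_shards: int) -> tuple[int, ...]:
    """Shard visiting order for one epoch; a pure function of (seed, epoch, num_shards)."""
    # hashlib is stable across processes and hosts; built-in hash() of str is salted per process.
    digest = hashlib.blake2b(f"{seed}:{epoch}".encode(), digest_size=8).digest()
    order = list(range(num_shards))
    random.Random(int.from_bytes(digest, "little")).shuffle(order)
    return tuple(order)


def shard_for(seed: int, epoch: int, step: int, rank: int,
              world_size: int, num_shards: int) -> int:
    """Shard consumed by `rank` at `step`; every shard is visited exactly once per epoch."""
    pos = step * world_size + rank
    if pos >= num_shards:
        raise IndexError("epoch exhausted; advance to epoch + 1")
    return epoch_order(seed, epoch, num_shards)[pos]
```

On restart, a rank only needs (seed, epoch, step) from the checkpoint to resume the identical stream; the cache is a speed-up, not state.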
5.3 Checkpointing
- Async save (don't block training step)
- Sharded checkpoint per rank → S3 with manifest
- Periodic full checkpoint: FP32 master weights + optimizer state + data-loader and RNG state (every ~hour)
- More frequent weights-only checkpoint (every ~10min) for eval branches
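The async-save pattern splits a checkpoint into a fast in-memory snapshot on the training thread and a slow write on a background thread. An atomic rename ensures a crash never leaves a half-written shard, and a manifest is published only after every rank's shard exists. A stdlib-only sketch in which JSON-serializable dicts stand in for tensors and a local directory stands in for S3; the class and function names are hypothetical:

```python
import copy
import json
import os
import threading
from typing import Optional


class AsyncCheckpointer:
    """Snapshot on the training thread, write on a background thread (sketch)."""

    def __init__(self, root: str):
        self.root = root
        self._thread: Optional[threading.Thread] = None

    def save(self, step: int, rank: int, state: dict) -> None:
        self.wait()                       # at most one save in flight per rank
        snapshot = copy.deepcopy(state)   # training may mutate `state` once save() returns
        self._thread = threading.Thread(
            target=self._write, args=(step, rank, snapshot), daemon=True)
        self._thread.start()

    def wait(self) -> None:
        if self._thread is not None:
            self._thread.join()
            self._thread = None

    def _write(self, step: int, rank: int, snapshot: dict) -> None:
        step_dir = os.path.join(self.root, f"step_{step:08d}")
        os.makedirs(step_dir, exist_ok=True)
        _atomic_write_json(os.path.join(step_dir, f"rank_{rank:05d}.json"), snapshot)


def _atomic_write_json(path: str, obj) -> None:
    tmp = path + ".tmp"
    with open(tmp, "w") as f:
        json.dump(obj, f)
        f.flush()
        os.fsync(f.fileno())
    os.replace(tmp, path)                 # readers never observe a partially written file


def write_manifest(root: str, step: int, world_size: int) -> bool:
    """Publish the step's manifest only once every rank's shard is present."""
    step_dir = os.path.join(root, f"step_{step:08d}")
    shards = [f"rank_{r:05d}.json" for r in range(world_size)]
    if not all(os.path.exists(os.path.join(step_dir, s)) for s in shards):
        return False
    _atomic_write_json(os.path.join(step_dir, "manifest.json"), {"step": step, "shards": shards})
    return True
```

In a real job the snapshot would be a device-to-host copy and rank 0 would call `write_manifest` after a barrier; restarts then load only steps that have a manifest.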
5.4 Failure Recovery
- Hardware: ECC errors, PSU failures, IB link flaps. Even at < 1% failure probability per node per day, 128 nodes see ~1 interruption/day
- Fast restart: training script idempotent on restart; ~5 min to rebuild parallel groups
- Fault detection: NCCL watchdog timeout 30s; bisect bad nodes; isolate and re-run
- Run health checks (GPU burn, NCCL all-reduce) before launch and every Nth restart
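The per-node failure rate translates into a cluster-level interruption rate, which in turn suggests a checkpoint interval. A sketch assuming independent node failures and using Young's first-order approximation, interval ≈ √(2 × checkpoint cost × MTBF); the 60 s blocking checkpoint cost is an assumed figure:

```python
import math


def cluster_failures(nodes: int, p_node_per_day: float) -> dict:
    """Independent per-node daily failure probability -> cluster-level interruption stats."""
    expected_per_day = nodes * p_node_per_day
    return {
        "expected_per_day": expected_per_day,
        "p_at_least_one_per_day": 1 - (1 - p_node_per_day) ** nodes,
        "mtbf_hours": 24 / expected_per_day,
    }


def young_checkpoint_interval_s(checkpoint_cost_s: float, mtbf_s: float) -> float:
    """Young's first-order optimum for the time between checkpoints."""
    return math.sqrt(2 * checkpoint_cost_s * mtbf_s)
```

With 128 nodes at 1%/node/day: ~1.28 interruptions/day (~72% chance of at least one), MTBF ≈ 18.75 h, and a 60 s checkpoint cost gives an interval of ≈ 2,846 s (~47 min), consistent with an hourly full-state cadence.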
5.5 Hyperparameter Plan
- LR schedule: linear warmup → cosine decay or WSD (warmup-stable-decay)
- Batch size: ramp up gradually (start 1M tokens/batch, end 4M)
- Weight decay 0.1, β = (0.9, 0.95), grad clip 1.0
- Sequence length: optionally curriculum (start 4k, ramp to 32k+)
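The schedules above written as pure functions of the step, which also makes them trivially resumable. A sketch: the peak LR, step counts, 20% WSD decay window, and 2^18-token rounding granularity (e.g., 4,096-token sequences × 64 DP ranks) are illustrative assumptions, and "1M"/"4M" are taken as 2^20/2^22 tokens:

```python
import math


def lr_warmup_cosine(step: int, peak_lr: float, warmup_steps: int, total_steps: int,
                     min_ratio: float = 0.1) -> float:
    """Linear warmup to peak_lr, then cosine decay to min_ratio * peak_lr."""
    if step < warmup_steps:
        return peak_lr * (step + 1) / warmup_steps
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    return peak_lr * (min_ratio + (1 - min_ratio) * 0.5 * (1 + math.cos(math.pi * progress)))


def lr_wsd(step: int, peak_lr: float, warmup_steps: int, total_steps: int,
           decay_frac: float = 0.2, min_ratio: float = 0.0) -> float:
    """Warmup-stable-decay: flat at peak_lr, then linear decay over the last decay_frac."""
    decay_start = int(total_steps * (1 - decay_frac))
    if step < warmup_steps:
        return peak_lr * (step + 1) / warmup_steps
    if step < decay_start:
        return peak_lr
    progress = min(1.0, (step - decay_start) / max(1, total_steps - decay_start))
    return peak_lr * (1 - (1 - min_ratio) * progress)


def batch_tokens(step: int, start: int = 2**20, end: int = 2**22,
                 ramp_steps: int = 10_000, granularity: int = 2**18) -> int:
    """Linear batch-size ramp in tokens, rounded down to a multiple of `granularity`."""
    frac = min(1.0, step / ramp_steps)
    raw = start + (end - start) * frac
    return max(start, int(raw) // granularity * granularity)
```

A practical upside of WSD is that the stable phase can be branched at any step and given its own short decay, without committing to a total step count up front.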
6. Bottlenecks & Scaling
| Bottleneck | Detection | Fix |
|---|---|---|
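| Comm-bound (low MFU < 30%) | NCCL takes > 30% of step | Bigger micro-batch, gradient accumulation, overlap comm with compute, FP8 all-gathers, fewer FSDP shards (e.g., hybrid sharding: shard within node, replicate across) |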
| Stragglers (tail node slow) | Step time variance | Identify the slow node (thermal throttling, flaky NIC) and cordon it; try NCCL tree vs ring if the interconnect topology favors tree |
| Data loader stall | GPU util dips between steps | Prefetch deeper; more workers; pin memory; check S3 throttling |
| Checkpoint blocking | Hiccup every N steps | Async save; persistent process |
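The table's thresholds need a consistent MFU definition. A sketch using the 6·P FLOPs/token approximation (attention FLOPs ignored, so it slightly understates MFU at long context) against an assumed 989 TFLOPS H100 BF16 peak, plus a toy triage applying the table's MFU < 30% and NCCL > 30% thresholds; the 5% data-wait threshold is a hypothetical placeholder:

```python
def mfu(tokens_per_s: float, params: float, n_gpus: int,
        peak_flops_per_gpu: float = 989e12) -> float:
    """Model FLOPs utilization with the 6*P FLOPs/token estimate (attention FLOPs ignored)."""
    return tokens_per_s * 6 * params / (n_gpus * peak_flops_per_gpu)


def triage(mfu_value: float, step_s: float, exposed_nccl_s: float, data_wait_s: float,
           data_wait_limit: float = 0.05) -> list[str]:
    """Map one step's timing breakdown to likely bottlenecks."""
    flags = []
    # exposed_nccl_s: communication time NOT hidden behind compute
    if mfu_value < 0.30 and exposed_nccl_s / step_s > 0.30:
        flags.append("comm-bound")
    # data_wait_limit is a hypothetical placeholder; the table gives no number
    if data_wait_s / step_s > data_wait_limit:
        flags.append("data-loader stall")
    return flags
```

For example, 1M tok/s on 1024 H100s for a 70B model is ≈ 41% MFU under this definition.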
7. Observability
Per step, log: loss, grad_norm, lr, param_norm, throughput (tok/s), MFU, NCCL time, data-load time, and alert on loss spikes or NaNs. Per hour: eval on a held-out slice; sample generations.
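A simple way to implement the loss-spike alert: compare each step's loss against the mean and standard deviation of a trailing window, and flag non-finite values immediately. A sketch; the window size, z-threshold, and warm-up length are hypothetical defaults to tune per run:

```python
import math
import statistics
from collections import deque


class LossSpikeDetector:
    """Flag a step whose loss is non-finite or exceeds mean + z * std of a trailing window."""

    def __init__(self, window: int = 200, z: float = 4.0, min_history: int = 50):
        self.history: deque = deque(maxlen=window)
        self.z = z
        self.min_history = min_history

    def update(self, loss: float) -> bool:
        if not math.isfinite(loss):
            return True                  # NaN/Inf: alert right away
        spike = False
        if len(self.history) >= self.min_history:
            mean = statistics.fmean(self.history)
            std = statistics.pstdev(self.history)
            spike = loss > mean + self.z * max(std, 1e-6)
        if not spike:
            self.history.append(loss)    # keep spikes out of the baseline
        return spike
```

Because spiking steps are not added to the window, a sustained divergence keeps alerting instead of quietly becoming the new baseline.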
8. Cost Model
- 1024 H100 × ~16 days × 24 h × $4/GPU-hr ≈ $1.6M (~393k GPU-hours), before restart and eval overhead.
- Storage (checkpoints + tokens): ~50 TB on S3 ≈ $1k/mo.
- Networking egress on restart: usually negligible (S3 in-region).
9. Tradeoffs
| Choice | Alternative | When |
|---|---|---|
| FSDP | DeepSpeed ZeRO-3 | FSDP is more PyTorch-native; ZeRO has more knobs |
| Megatron-LM | nanotron / torchtitan | Megatron is battle-tested; new stacks easier to modify |
| BF16 + FP8 | Pure BF16 | FP8 once you've convinced yourself the model is stable |
| Dense | MoE | MoE = better tok/$ at training & serving but harder eval/RLHF |
10. Pitch
"70B on 1024 H100s means TP=4 within-node, PP=4 across nodes, DP=64 with FSDP sharding optimizer states. FP32 master weights, BF16 compute, FP8 matmuls once stability is established (FP8 Tensor Core peak is 2× BF16 on H100). Async checkpointing: weights-only every 10 min, full state hourly. Deterministic resumable data loader keyed on (epoch, step, rank). NCCL watchdog catches hung collectives. Target 45% MFU; alert if we drop below 35%. Total run: ~16 days of compute, ~$1.6M, on 1.5T tokens."