02 — Distributed Pretraining (8B → 70B)

Roles: Research Engineer, Pretraining (Anthropic, OpenAI, DeepMind, Meta, xAI)


1. Clarifying Questions

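  • Target model size and token budget? (Chinchilla compute-optimal is ~20 tokens/param, so 8B → ~160B tokens and 70B → ~1.4T tokens. Production runs often train well past this to get a smaller model that is cheaper to serve.)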
  • Hardware: H100 / H200 / TPUv5p? How many nodes? Interconnect (NVLink + InfiniBand / TPU ICI)?
  • Training duration target? (Days? Weeks?)
  • Checkpointing / restart frequency?
  • Mixed-precision (BF16 + FP8)?
  • Architecture: dense vs MoE?

2. Capacity Estimation

Example: 70B dense model, 1.5T tokens, BF16 + FSDP.

  • Params: 70B × 2 bytes = 140 GB (weights)
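  • Gradients: 70B × 2 bytes (BF16) = 140 GB
  • Optimizer states (AdamW: FP32 master weights + FP32 moments m and v): 12 bytes/param = 840 GB
  • Activations (with recompute): scale with batch × seq × hidden × layers
  • Total per model replica: ~1.12 TB (16 bytes/param) before activations, vs 80 GB HBM per H100 → MUST be sharded (FSDP/ZeRO-3, TP, PP)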
  • Compute: 6 × P × T flops ≈ 6 × 70e9 × 1.5e12 = 6.3e23 flops
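  • Planning throughput: 400 TFLOPS × 45% ≈ 180 TFLOPS effective per H100, so 6.3e23 / 1.8e14 ≈ 3.5e9 GPU-seconds (≈ 970k GPU-hours). This is conservative: H100 SXM dense BF16 peak is ~989 TFLOPS, and a true 45% MFU against that peak (~445 TFLOPS) would finish ~2.5× sooner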
  • → 1024 H100s for ~40 days, or 4096 H100s for ~10 days
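The estimate above can be reproduced in a few lines. This is a back-of-envelope sketch using the text's own assumptions (16 bytes/param of training state, 6·P·T flops, 180 TFLOPS effective per GPU); the helper names are illustrative and not from any library.

```python
# Back-of-envelope capacity estimate for the 70B / 1.5T-token example.
# Assumed (from the text): BF16 weights and grads, FP32 master weights +
# Adam moments, and 400 TFLOPS x 45% = 180 TFLOPS effective per GPU.

SECONDS_PER_DAY = 86_400


def training_state_bytes(params: float) -> float:
    """Weights + grads + optimizer state in bytes (activations excluded)."""
    bf16_weights = 2 * params
    bf16_grads = 2 * params
    fp32_master_and_moments = 12 * params  # master copy + Adam m + Adam v
    return bf16_weights + bf16_grads + fp32_master_and_moments


def train_flops(params: float, tokens: float) -> float:
    """6*P*T: ~2 flops/param/token in the forward pass, ~4 in the backward."""
    return 6 * params * tokens


def wall_clock_days(flops: float, n_gpus: int, flops_per_gpu: float) -> float:
    """Days of training if every GPU sustains `flops_per_gpu` for the whole run."""
    return flops / flops_per_gpu / n_gpus / SECONDS_PER_DAY


P, T = 70e9, 1.5e12
EFFECTIVE = 400e12 * 0.45

state_tb = training_state_bytes(P) / 1e12
flops = train_flops(P, T)
gpu_seconds = flops / EFFECTIVE

print(f"training state: {state_tb:.2f} TB")  # 1.12 TB
print(f"compute: {flops:.2e} flops")          # 6.30e+23
print(f"GPU-seconds: {gpu_seconds:.2e}")      # 3.50e+09
for n in (1024, 4096):
    print(f"{n} GPUs: {wall_clock_days(flops, n, EFFECTIVE):.1f} days")
```

Swapping `EFFECTIVE` for `989e12 * 0.45` shows how sensitive the schedule is to the throughput assumption.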

3. Parallelism Plan

Dim | Strategy | Why
--- | --- | ---
Data (DP) | DDP / FSDP across replicas | Throughput
Tensor (TP) | Megatron-style, within node (TP=4 or 8 over NVLink) | Reduce per-GPU memory; avoid cross-node TP (latency!)
Pipeline (PP) | 1F1B or interleaved schedules across nodes | Fit 70B+ across nodes
Sequence/Context (SP/CP) | Ring attention | Long context (128k+)
Expert (EP) | Top-2 routing, capacity factor 1.25 | If MoE

Composition example (70B dense, 1024 H100, 8/node):

  • TP = 4 (within node)
  • PP = 4 (across nodes — partitions of layers)
  • DP = 64 (replicas) → 4 × 4 × 64 = 1024
  • Optimizer states sharded across the 64 DP ranks (ZeRO-1 / distributed optimizer); full FSDP/ZeRO-3 would also shard grads and params
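A rough check that this layout fits in an 80 GB H100. It is a sketch that assumes parameters split evenly across the 16 TP × PP slices (real layouts are uneven, e.g. embeddings and LM head on the edge stages) and ignores activations and communication buffers; `per_gpu_state_gb` and its flags are hypothetical names.

```python
# Per-GPU training-state memory for TP=4, PP=4, DP=64 (1024 GPUs).
# Assumption: params split evenly over the TP x PP grid; activations,
# comm buffers, and allocator fragmentation are ignored.

def per_gpu_state_gb(params: float, tp: int, pp: int, dp: int,
                     shard_optim: bool = True, shard_grads: bool = False,
                     shard_params: bool = False) -> float:
    local = params / (tp * pp)  # params owned by one TP x PP slice, replicated across DP
    weights = 2 * local / (dp if shard_params else 1)  # BF16
    grads = 2 * local / (dp if shard_grads else 1)     # BF16
    optim = 12 * local / (dp if shard_optim else 1)    # FP32 master + m + v
    return (weights + grads + optim) / 1e9


TP, PP, DP = 4, 4, 64
assert TP * PP * DP == 1024

no_shard = per_gpu_state_gb(70e9, TP, PP, DP, shard_optim=False)
optim_shard = per_gpu_state_gb(70e9, TP, PP, DP)  # ZeRO-1 style
full_shard = per_gpu_state_gb(70e9, TP, PP, DP,
                              shard_grads=True, shard_params=True)  # ZeRO-3 / full FSDP
print(f"no DP sharding:     {no_shard:.1f} GB")
print(f"optimizer sharded:  {optim_shard:.1f} GB")
print(f"everything sharded: {full_shard:.1f} GB")
```

Without DP sharding each GPU holds ~70 GB of state before any activations; sharding only the FP32 optimizer state drops that to ~18 GB, which is usually enough once TP and PP have already cut the model 16 ways.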

4. Architecture

Coordinator/Scheduler (Slurm / k8s + Volcano)
        │
        ▼
    [Job: 1024 H100s = 128 nodes × 8 GPUs, 16 racks, fat-tree IB]
        │
        ├── Rank-0 driver: checkpoint coordination (writes the manifest), evals, logging
        ├── Data loader workers (per node): stream from object store
        ├── Tokenized shards (uint16 .bin if vocab ≤ 65,536, else a 4-byte type) on local NVMe (warmed from S3)
        ├── Async checkpointing → S3 (fully sharded: every rank writes its own shard every N steps)
        └── Telemetry: every step, every rank → Prometheus / W&B / ClearML
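A minimal sketch of writing and reading a `uint16 .bin` shard, assuming the simplest possible layout: a headerless, native-endian array of token IDs (a real format may add a header or an index). `write_shard`, `read_shard`, and `get_sample` are hypothetical helpers; the standard-library `array` module keeps the sketch dependency-free, whereas a production loader would more likely memory-map the file (for example `numpy.memmap(path, dtype=numpy.uint16, mode="r")`).

```python
import array
import os
import tempfile

# Hypothetical shard layout: headerless, native-endian uint16 token IDs.
TOKEN_TYPECODE = "H"  # unsigned short; 2 bytes on mainstream platforms


def write_shard(path: str, token_ids: list[int]) -> None:
    buf = array.array(TOKEN_TYPECODE, token_ids)  # OverflowError for ids > 65535
    with open(path, "wb") as f:
        buf.tofile(f)


def read_shard(path: str) -> array.array:
    buf = array.array(TOKEN_TYPECODE)
    n_tokens = os.path.getsize(path) // buf.itemsize
    with open(path, "rb") as f:
        buf.fromfile(f, n_tokens)
    return buf


def get_sample(tokens: array.array, offset: int, seq_len: int) -> tuple[list[int], list[int]]:
    """Next-token-prediction pair: targets are the inputs shifted left by one."""
    window = tokens[offset : offset + seq_len + 1]
    return list(window[:-1]), list(window[1:])


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "shard_00000.bin")
    write_shard(path, [5, 17, 65535, 3, 42])
    tokens = read_shard(path)
    inputs, targets = get_sample(tokens, offset=1, seq_len=3)
    print(inputs, targets)  # [17, 65535, 3] [65535, 3, 42]
```

The `OverflowError` on IDs above 65,535 is the practical reason vocabularies larger than 64k need a wider token type.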

5. Deep Dives

5.1 Numerical Stability

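  • BF16 compute; FP32 master weights and optimizer states; gradient accumulation and reductions in FP32
  • FP8 matmuls with per-tensor scaling on Hopper Tensor Cores (typically E4M3 in the forward pass, E5M2 for gradients); watch for unstable layers and keep the sensitive ones (often the LM head) in BF16
  • No loss scaling needed in BF16 (it has the same 8-bit exponent range as FP32)
  • Gradient clipping at 1.0
  • Residual-stream variance grows with depth: use scaled init (e.g. output projections scaled by 1/√(2·n_layers)), μP if hyperparameters must transfer across widths, and QK-norm to keep attention logits bounded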
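To see why per-tensor scaling matters, here is a pure-Python simulation of rounding into FP8 E4M3 (3 mantissa bits, max finite value 448, subnormals below 2^-6). It is an illustrative sketch that saturates instead of handling NaN/inf and is not how any library implements FP8; production stacks such as NVIDIA Transformer Engine also offer delayed scaling from an amax history rather than the just-in-time amax used here. All names are hypothetical.

```python
import math

E4M3_MAX = 448.0      # largest finite value in the E4M3 ("fn") FP8 format
E4M3_MIN_EXP = -6     # smallest normal exponent; below it values are subnormal
MANTISSA_BITS = 3


def round_to_e4m3(x: float) -> float:
    """Simulated round-to-nearest into E4M3, saturating at +/-448 (no NaN/inf handling)."""
    if x == 0.0:
        return 0.0
    sign = -1.0 if x < 0 else 1.0
    a = min(abs(x), E4M3_MAX)
    exp = max(math.floor(math.log2(a)), E4M3_MIN_EXP)
    step = 2.0 ** (exp - MANTISSA_BITS)  # spacing between representable values here
    return sign * min(round(a / step) * step, E4M3_MAX)


def quantize_per_tensor(values: list[float]) -> tuple[list[float], float]:
    """Per-tensor scaling: map the tensor's amax onto E4M3_MAX, then round."""
    amax = max(abs(v) for v in values)
    scale = E4M3_MAX / amax if amax > 0 else 1.0
    return [round_to_e4m3(v * scale) for v in values], scale


def dequantize(q: list[float], scale: float) -> list[float]:
    return [v / scale for v in q]


x = [0.001, 0.5, -2.0, 3.0]
q, scale = quantize_per_tensor(x)
restored = dequantize(q, scale)
rel_err_scaled = max(abs(r - v) / abs(v) for r, v in zip(restored, x))
rel_err_unscaled = max(abs(round_to_e4m3(v) - v) / abs(v) for v in x)
print(f"max relative error: scaled {rel_err_scaled:.3f}, unscaled {rel_err_unscaled:.3f}")
```

Unscaled, 0.001 falls into the subnormal range and comes back as 2^-9 ≈ 0.002; with scaling every element stays within E4M3's ~6% half-step error.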

5.2 Data Loading at Scale

  • Shards on S3 (10s of TB tokenized)
  • Stripe across NVMe on each node; double-buffer; prefetch 2 batches ahead
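  • Deterministic document interleaving: hash(epoch, rank, step) → shard, so every rank can compute its data position without coordination
  • Resumable: on restart, jump to (epoch, step); each rank then deterministically reproduces the same batches (provided the seed, mapping, and global batch size are unchanged)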
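One way to make the mapping resumable, sketched under stated assumptions: instead of hashing (epoch, rank, step) directly to a shard, which can hand two ranks the same data, this variant hashes (seed, epoch, index) to build a per-epoch permutation and derives each rank's slice from (step, rank), so every sample is read exactly once per epoch. It assumes a fixed global batch size; `epoch_order` and `batch_for` are hypothetical names.

```python
import hashlib


def _sort_key(seed: int, epoch: int, idx: int) -> bytes:
    # Keyed hash: stable across processes, machines, and Python versions.
    return hashlib.sha256(f"{seed}:{epoch}:{idx}".encode()).digest()


def epoch_order(num_samples: int, epoch: int, seed: int = 0) -> list[int]:
    """Deterministic per-epoch permutation of sample indices."""
    return sorted(range(num_samples), key=lambda i: _sort_key(seed, epoch, i))


def batch_for(step: int, rank: int, world_size: int, per_rank_batch: int,
              num_samples: int, seed: int = 0) -> list[int]:
    """Sample indices that `rank` reads at global `step`: a pure function of its arguments."""
    global_batch = world_size * per_rank_batch
    steps_per_epoch = num_samples // global_batch  # the ragged tail is dropped
    if steps_per_epoch == 0:
        raise ValueError("dataset smaller than one global batch")
    epoch, step_in_epoch = divmod(step, steps_per_epoch)
    order = epoch_order(num_samples, epoch, seed)  # cache per epoch in practice
    start = step_in_epoch * global_batch + rank * per_rank_batch
    return order[start:start + per_rank_batch]
```

After a crash, a rank restarting at step 13 calls `batch_for(13, rank, ...)` and gets exactly the indices it would have read, with no saved iterator state beyond the step counter.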

5.3 Checkpointing

  • Async save (don't block training step)
  • Sharded checkpoint per rank → S3 with manifest
  • Periodic full training-state checkpoint (FP32 master weights + optimizer moments, ~840 GB for 70B; every ~hour)
  • More frequent weights-only checkpoint (every ~10min) for eval branches
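The async, sharded save pattern can be sketched as: snapshot the state synchronously (cheap), serialize it on a background thread, write each shard via temp file + rename, and commit the checkpoint with a manifest written last. Plain `pickle` and a local directory stand in for a real distributed checkpointing library and S3; `AsyncCheckpointer` and `write_manifest` are hypothetical names.

```python
import copy
import json
import os
import pickle
import tempfile
import threading


class AsyncCheckpointer:
    """Snapshot state on the training thread; serialize it on a background thread."""

    def __init__(self, root: str):
        self.root = root
        self._thread = None

    def save(self, step: int, rank: int, state: dict) -> None:
        self.wait()                      # keep at most one save in flight per rank
        snapshot = copy.deepcopy(state)  # stands in for a GPU -> pinned-CPU copy
        self._thread = threading.Thread(target=self._write, args=(step, rank, snapshot))
        self._thread.start()             # training continues while this runs

    def _write(self, step: int, rank: int, snapshot: dict) -> None:
        ckpt_dir = os.path.join(self.root, f"step_{step:08d}")
        os.makedirs(ckpt_dir, exist_ok=True)
        final = os.path.join(ckpt_dir, f"rank_{rank:05d}.pkl")
        tmp = final + ".tmp"
        with open(tmp, "wb") as f:
            pickle.dump(snapshot, f)
            f.flush()
            os.fsync(f.fileno())
        os.replace(tmp, final)           # atomic rename on a local POSIX filesystem

    def wait(self) -> None:
        if self._thread is not None:
            self._thread.join()
            self._thread = None


def write_manifest(root: str, step: int, world_size: int) -> None:
    """Commit marker, written last (in a real job: by rank 0 after a barrier)."""
    ckpt_dir = os.path.join(root, f"step_{step:08d}")
    shards = [f"rank_{r:05d}.pkl" for r in range(world_size)]
    missing = [s for s in shards if not os.path.exists(os.path.join(ckpt_dir, s))]
    if missing:
        raise RuntimeError(f"incomplete checkpoint at step {step}: missing {missing}")
    tmp = os.path.join(ckpt_dir, "manifest.json.tmp")
    with open(tmp, "w") as f:
        json.dump({"step": step, "world_size": world_size, "shards": shards}, f)
    os.replace(tmp, os.path.join(ckpt_dir, "manifest.json"))


# Usage: four simulated ranks save step 100, then the manifest commits it.
with tempfile.TemporaryDirectory() as root:
    states = [{"rank": r, "weights": [float(r)] * 3} for r in range(4)]
    savers = [AsyncCheckpointer(root) for _ in range(4)]
    for rank, saver in enumerate(savers):
        saver.save(step=100, rank=rank, state=states[rank])
    states[0]["weights"][0] = 999.0      # the next "training step" mutates live state
    for saver in savers:
        saver.wait()
    write_manifest(root, step=100, world_size=4)
    with open(os.path.join(root, "step_00000100", "manifest.json")) as f:
        manifest = json.load(f)
    with open(os.path.join(root, "step_00000100", "rank_00000.pkl"), "rb") as f:
        restored_rank0 = pickle.load(f)
```

S3 has no atomic rename, so the manifest-last convention does that job there: a checkpoint directory without a manifest is treated as incomplete on restart.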

5.4 Failure Recovery

  • Hardware: ECC errors, PSU failures, IB link flaps. Per-node failure rates below 1%/day still add up at scale: 128 nodes × 1%/day ≈ 1.3 expected interruptions per day
  • Fast restart: training script idempotent on restart; ~5 min to rebuild parallel groups
  • Fault detection: NCCL watchdog timeout 30s; bisect bad nodes; isolate and re-run
  • Run health checks (GPU burn, NCCL all-reduce) before launch and every Nth restart
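The failure arithmetic also suggests how often to checkpoint. This sketch assumes independent node failures and uses Young's first-order approximation (interval ≈ √(2 × save cost × MTBF)), a standard rule of thumb swapped in here rather than anything the plan above prescribes. The 1%/node/day rate and the 30 s blocking save cost are illustrative assumptions; the function names are hypothetical.

```python
import math

SECONDS_PER_DAY = 86_400


def expected_interruptions_per_day(n_nodes: int, p_node_per_day: float) -> float:
    """Expected node failures per day across the job (independent failures)."""
    return n_nodes * p_node_per_day


def p_any_failure_per_day(n_nodes: int, p_node_per_day: float) -> float:
    """Probability that at least one node fails within a day."""
    return 1.0 - (1.0 - p_node_per_day) ** n_nodes


def young_checkpoint_interval_s(save_cost_s: float, job_mtbf_s: float) -> float:
    """Young's approximation: interval ~ sqrt(2 * checkpoint cost * MTBF)."""
    return math.sqrt(2.0 * save_cost_s * job_mtbf_s)


rate = expected_interruptions_per_day(128, 0.01)  # 1.28 per day
mtbf_s = SECONDS_PER_DAY / rate                    # 67,500 s (~18.75 h)
interval_s = young_checkpoint_interval_s(30.0, mtbf_s)
print(f"{rate:.2f}/day, P(>=1 failure/day)={p_any_failure_per_day(128, 0.01):.2f}, "
      f"checkpoint every ~{interval_s / 60:.0f} min")
```

Async saves shrink the blocking cost, which pushes the optimal interval shorter; that is why frequent weights-only saves are cheap to justify.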

5.5 Hyperparameter Plan

  • LR schedule: linear warmup → cosine decay or WSD (warmup-stable-decay)
  • Batch size: ramp up gradually (start 1M tokens/batch, end 4M)
  • Weight decay 0.1, β = (0.9, 0.95), grad clip 1.0
  • Sequence length: optionally curriculum (start 4k, ramp to 32k+)
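The schedules above can be sketched as pure functions of the step (or tokens seen), which also keeps them trivially resumable. The decay shape in WSD (linear here), `decay_frac=0.1`, and `ramp_tokens=100e9` are illustrative assumptions, and all function names are hypothetical; real batch ramps usually move in multiples of the per-step micro-batch size rather than continuously.

```python
import math


def lr_warmup_cosine(step: int, peak_lr: float, warmup: int, total: int,
                     min_lr: float = 0.0) -> float:
    """Linear warmup to peak_lr, then cosine decay to min_lr at `total` steps."""
    if step < warmup:
        return peak_lr * (step + 1) / warmup
    progress = min(1.0, (step - warmup) / max(1, total - warmup))
    return min_lr + 0.5 * (peak_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


def lr_wsd(step: int, peak_lr: float, warmup: int, total: int,
           decay_frac: float = 0.1, min_lr: float = 0.0) -> float:
    """Warmup-stable-decay: warmup, hold at peak, then decay over the last decay_frac of steps."""
    decay_start = int(total * (1.0 - decay_frac))
    if step < warmup:
        return peak_lr * (step + 1) / warmup
    if step < decay_start:
        return peak_lr
    progress = min(1.0, (step - decay_start) / max(1, total - decay_start))
    return peak_lr + (min_lr - peak_lr) * progress  # linear decay; other shapes are also used


def batch_tokens(tokens_seen: float, start: int = 1_000_000, end: int = 4_000_000,
                 ramp_tokens: float = 100e9) -> int:
    """Linear global-batch ramp from `start` to `end` tokens over the first `ramp_tokens`."""
    frac = min(1.0, tokens_seen / ramp_tokens)
    return int(start + (end - start) * frac)
```

A practical appeal of WSD is that the stable phase has no fixed end: a checkpoint from the plateau can be branched and decayed later without committing to `total` up front.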

6. Bottlenecks & Scaling

Bottleneck | Detection | Fix
--- | --- | ---
Comm-bound (low MFU, < 30%) | NCCL takes > 30% of step time | Bigger micro-batch, gradient accumulation, FP8, fewer FSDP shards
Stragglers (slow tail node) | Step-time variance | Identify the hot node; compare NCCL ring vs tree (use tree if the interconnect topology favors it)
Data-loader stall | GPU utilization dips between steps | Prefetch deeper; more workers; pinned memory; check S3 throttling
Checkpoint blocking | Hiccup every N steps | Async save; persistent writer process
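The detection column can be turned into a quick check from logged metrics. This sketch computes MFU from throughput with the 6·P flops/token estimate (which ignores the attention term that grows with sequence length) and applies the table's 30% thresholds to exposed, i.e. non-overlapped, NCCL time. 989e12 approximates the H100 SXM dense BF16 peak; the function names are hypothetical.

```python
H100_SXM_BF16_DENSE_PEAK = 989e12  # approx. dense BF16 Tensor Core peak, flops/s


def mfu(tokens_per_s: float, params: float, n_gpus: int,
        peak_flops: float = H100_SXM_BF16_DENSE_PEAK) -> float:
    """Model FLOPs utilization from the 6*P flops/token estimate (attention term ignored)."""
    return 6.0 * params * tokens_per_s / (n_gpus * peak_flops)


def exposed_comm_fraction(exposed_nccl_s: float, step_s: float) -> float:
    """Share of the step spent in communication that is NOT overlapped with compute."""
    return exposed_nccl_s / step_s


def diagnose(tokens_per_s: float, params: float, n_gpus: int,
             exposed_nccl_s: float, step_s: float) -> tuple[float, str]:
    u = mfu(tokens_per_s, params, n_gpus)
    comm = exposed_comm_fraction(exposed_nccl_s, step_s)
    if u < 0.30 and comm > 0.30:
        return u, "comm-bound"
    if u < 0.30:
        return u, "low MFU, not comm: check data loader / stragglers"
    return u, "ok"


# 70B on 1024 GPUs at 1M tok/s with a 4M-token batch: ~4 s steps.
print(diagnose(1.0e6, 70e9, 1024, exposed_nccl_s=0.5, step_s=4.0))
print(diagnose(0.5e6, 70e9, 1024, exposed_nccl_s=2.0, step_s=4.0))
```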

7. Observability

Per step, log: loss, grad_norm, lr, param_norm, throughput (tok/s), MFU, NCCL time, data-load time. Per hour: eval on a held-out slice and sample generations. Alert on loss spikes.
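A simple loss-spike alert, sketched as an exponential-moving-average baseline with an EMA variance: a step is flagged when its loss exceeds the baseline by `k` standard deviations. The defaults (`alpha=0.01`, `k=4.0`, 100 warmup steps) are illustrative assumptions, and `LossSpikeDetector` is a hypothetical name. Spikes are kept out of the baseline, so a persistent jump in loss keeps alerting until someone looks at it.

```python
import math


class LossSpikeDetector:
    """Flag a step whose loss exceeds an exponential-moving-average baseline by k std devs."""

    def __init__(self, alpha: float = 0.01, k: float = 4.0, warmup_steps: int = 100):
        self.alpha, self.k, self.warmup_steps = alpha, k, warmup_steps
        self.mean = None
        self.var = 0.0
        self.n = 0

    def update(self, loss: float) -> bool:
        self.n += 1
        if self.mean is None:  # first observation seeds the baseline
            self.mean = loss
            return False
        diff = loss - self.mean
        std = math.sqrt(self.var)
        spike = self.n > self.warmup_steps and diff > self.k * max(std, 1e-8)
        if not spike:  # keep spikes out of the baseline statistics
            incr = self.alpha * diff
            self.mean += incr
            self.var = (1.0 - self.alpha) * (self.var + diff * incr)
        return spike
```

In the training loop this is one call per step on rank 0, e.g. `if detector.update(loss): ...` followed by whatever paging or auto-rollback hook the job uses.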

8. Cost Model

  • Compute: 1024 H100 × 40 days × 24 h ≈ 983k GPU-hours × $4/GPU-hr ≈ $3.9M.
  • Storage (checkpoints + tokens): ~50 TB on S3 ≈ $1k/mo.
  • Networking egress on restart: usually negligible (S3 in-region).
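The cost lines reduce to two multiplications. This sketch assumes an S3 Standard-like price of $0.023/GB-month and two months of retention for the 50 TB; both are illustrative inputs, not quotes, and `run_cost_usd` is a hypothetical helper.

```python
def run_cost_usd(n_gpus: int, days: float, usd_per_gpu_hour: float,
                 storage_tb: float, usd_per_gb_month: float,
                 storage_months: float) -> tuple[float, float, float]:
    """Return (GPU-hours, compute cost, storage cost) for a training run."""
    gpu_hours = n_gpus * days * 24
    compute = gpu_hours * usd_per_gpu_hour
    storage = storage_tb * 1_000 * usd_per_gb_month * storage_months
    return gpu_hours, compute, storage


gpu_hours, compute, storage = run_cost_usd(1024, 40, 4.0, 50, 0.023, 2)
print(f"{gpu_hours:,} GPU-hours -> ${compute / 1e6:.2f}M compute, ${storage:,.0f} storage")
```

Storage is a rounding error next to compute, so the levers that matter are MFU and lost time from failures.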

9. Tradeoffs

Choice | Alternative | When
--- | --- | ---
FSDP | DeepSpeed ZeRO-3 | FSDP is more PyTorch-native; ZeRO has more knobs
Megatron-LM | nanotron / torchtitan | Megatron is battle-tested; newer stacks are easier to modify
BF16 + FP8 | Pure BF16 | FP8 once you've convinced yourself the model is stable
Dense | MoE | MoE gives better tokens/$ in training and serving, but eval and RLHF are harder

10. Pitch

"70B on 1024 H100s means TP=4 within node, PP=4 across nodes, DP=64 with optimizer states sharded across DP. BF16 compute with FP32 master weights, plus FP8 matmuls for ~1.6× throughput. Async checkpointing: weights-only every 10 min, full state hourly. Deterministic, resumable data loader keyed on (epoch, step, rank). NCCL watchdog catches hung collectives; step-time variance flags stragglers. Target 45% MFU; alert if we drop below 35%. Total run: ~40 days, ~$4M, on 1.5T tokens."