03 — Systems Questions
Performance, parallelism, memory, profiling — the gritty side asked in LLM Infra / Inference / Pretraining interviews.
A. Memory & Throughput
Q. How much GPU memory does Llama-3-8B need to serve at 8k context, batch=8, BF16?
- Weights: 8B × 2 bytes = 16 GB
- KV cache per request: `2 × n_layers × n_kv_heads × d_head × seq × bytes`
- Llama-3-8B: 32 layers, 8 KV heads (GQA), 128 d_head, BF16 = 2 bytes
- = 2 × 32 × 8 × 128 × 8192 × 2 ≈ 1.07 GB / request
- Batch=8 → ≈ 8.6 GB KV
- Total: 16 + 8.6 + ~2 GB activations + framework overhead ≈ 28 GB → fits A100 40GB with headroom, comfortable on H100 80GB
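The arithmetic above can be reproduced with a few lines of Python. This is a minimal sketch: the function name `kv_cache_bytes` is made up for illustration, and activations and framework overhead are left out because they depend on the serving stack.

```python
def kv_cache_bytes(n_layers, n_kv_heads, d_head, seq_len, bytes_per_elem=2):
    # Leading 2 = one K tensor plus one V tensor per layer.
    return 2 * n_layers * n_kv_heads * d_head * seq_len * bytes_per_elem

# Llama-3-8B shape from the answer above: 32 layers, 8 KV heads, d_head=128, BF16.
per_request = kv_cache_bytes(n_layers=32, n_kv_heads=8, d_head=128, seq_len=8192)
kv_total = 8 * per_request   # batch = 8
weights = 8e9 * 2            # 8B params x 2 bytes (BF16)

print(f"KV per request:  {per_request / 1e9:.2f} GB")
print(f"KV for batch=8:  {kv_total / 1e9:.2f} GB")
print(f"weights + KV:    {(weights + kv_total) / 1e9:.2f} GB")
```

Without GQA (32 KV heads instead of 8) the same call gives 4× the KV footprint, which is why the KV-head count matters as much as the layer count.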
Q. Why does throughput plateau even when GPU util is 100%?
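You're memory-bandwidth bound, not compute bound. GPU util only reports that a kernel was running, not how busy the math units were. Decode-time matmuls have low arithmetic intensity: every weight byte loaded is used for only a few tokens' worth of FLOPs. Fixes: bigger batch (more arithmetic per byte loaded), quantize weights (fewer bytes loaded), speculative decoding (more useful tokens per pass over the weights).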
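A rough ceiling on decode throughput follows from this: each decode step has to stream the full weight set from HBM once, and that read is shared by every sequence in the batch. The sketch below assumes a round 2 TB/s of HBM bandwidth (an illustrative figure, not a measurement) and ignores KV-cache reads, so real numbers will be lower.

```python
def decode_tokens_per_s_upper_bound(weight_bytes, hbm_bytes_per_s, batch):
    # One decode step reads all weights once; the batch shares that read.
    step_seconds = weight_bytes / hbm_bytes_per_s
    return batch / step_seconds

HBM_BW = 2e12        # ASSUMPTION: ~2 TB/s, round illustrative figure
BF16_8B = 8e9 * 2    # 8B params in BF16
INT4_8B = 8e9 * 0.5  # same model with 4-bit weights

for label, wbytes, batch in [
    ("bf16, batch=1", BF16_8B, 1),
    ("bf16, batch=8", BF16_8B, 8),
    ("int4, batch=1", INT4_8B, 1),
]:
    bound = decode_tokens_per_s_upper_bound(wbytes, HBM_BW, batch)
    print(f"{label}: <= {bound:.0f} tokens/s")
```

Batching and quantization both raise the ceiling without adding any FLOPs, which is the signature of a bandwidth-bound workload.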
Q. Roofline analysis: which side of the roofline is your kernel on?
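Plot arithmetic intensity (FLOP/byte) on the x-axis against achieved FLOP/s on the y-axis. Kernels under the sloped part of the roof (left of the ridge point, peak FLOP/s ÷ peak bandwidth) are bandwidth-bound; kernels under the flat part are compute-bound. Decode is bandwidth-bound, prefill is compute-bound, so each calls for different optimizations.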
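A small sketch of the roofline check. The peak numbers are assumptions in the range of an A100-class part (312 TFLOP/s BF16, 2 TB/s HBM); substitute the figures for your hardware. The intensity estimate for a BF16 weight matmul, roughly one FLOP per byte per token, ignores activation traffic.

```python
def attainable_flops(intensity, peak_flops, peak_bw):
    # Roofline: the lower of the compute roof and the bandwidth slope.
    return min(peak_flops, intensity * peak_bw)

def regime(intensity, peak_flops, peak_bw):
    ridge = peak_flops / peak_bw  # FLOP/byte where the slope meets the flat roof
    return "compute-bound" if intensity >= ridge else "bandwidth-bound"

PEAK_FLOPS = 312e12  # ASSUMPTION: dense BF16 peak, FLOP/s
PEAK_BW = 2e12       # ASSUMPTION: HBM bandwidth, bytes/s

# BF16 weight matmul: 2 FLOPs per weight per token, 2 bytes per weight
# -> intensity ~ number of tokens sharing one weight read.
decode_intensity = 8      # batch of 8, one token each
prefill_intensity = 8192  # one 8k-token prompt

print("ridge point:", PEAK_FLOPS / PEAK_BW, "FLOP/byte")
print("decode :", regime(decode_intensity, PEAK_FLOPS, PEAK_BW))
print("prefill:", regime(prefill_intensity, PEAK_FLOPS, PEAK_BW))
```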
B. Parallelism
Q. When would you use TP vs PP vs FSDP?
| Need | Choice |
|---|---|
| Reduce memory across DP replicas | FSDP / ZeRO-3 |
| Model too big for one GPU | TP (within node) |
| Model too big for one node | PP (across nodes) |
| Long context (>128k) | Sequence/Context parallelism |
| MoE | Expert parallelism |
Real systems combine all of these. TP intra-node (NVLink), PP inter-node, FSDP for the data-parallel dim.
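When the dimensions are combined, the degrees multiply to the world size, so fixing TP and PP determines the data-parallel degree. A minimal sketch with a hypothetical helper name and an example cluster of 64 GPUs (8 per node):

```python
def data_parallel_degree(world_size, tp, pp):
    # world_size = tp * pp * dp; whatever is left after TP and PP is DP/FSDP.
    if world_size % (tp * pp) != 0:
        raise ValueError("world_size must be divisible by tp * pp")
    return world_size // (tp * pp)

GPUS_PER_NODE = 8  # ASSUMPTION: example cluster shape
tp = GPUS_PER_NODE # TP spans one node (NVLink)
pp = 2             # PP spans two nodes
dp = data_parallel_degree(world_size=64, tp=tp, pp=pp)
print(f"TP={tp} x PP={pp} x DP={dp} = {tp * pp * dp} GPUs")
```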
Q. Why is TP usually capped at the node size?
TP requires an all-reduce after each attention block and each MLP block: ~2 collectives per layer × N layers in the forward pass, and as many again in the backward pass. Within-node NVLink (~600 GB/s) keeps this fast; cross-node InfiniBand (~25 GB/s effective per GPU) is more than 10× slower (≈24× by these figures), which kills throughput.
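To put numbers on this, the sketch below estimates forward-pass TP communication time with the standard ring all-reduce cost of 2(n-1)/n × message size per GPU. It uses the bandwidth figures quoted above, assumes Llama-3-8B's hidden size of 4096, and ignores per-collective latency and any overlap with compute, so treat it as a bandwidth-only estimate.

```python
def ring_allreduce_seconds(msg_bytes, n_gpus, bw_bytes_per_s):
    # Ring all-reduce: each GPU sends and receives 2*(n-1)/n of the message.
    return 2 * (n_gpus - 1) / n_gpus * msg_bytes / bw_bytes_per_s

def tp_forward_comm_seconds(tokens, d_model, n_layers, tp, bw, bytes_per_elem=2):
    msg_bytes = tokens * d_model * bytes_per_elem  # one activation tensor
    # Two all-reduces per layer: after attention and after the MLP.
    return 2 * n_layers * ring_allreduce_seconds(msg_bytes, tp, bw)

NVLINK_BW = 600e9  # bytes/s, figure quoted in the text
IB_BW = 25e9       # bytes/s per GPU, figure quoted in the text

fast = tp_forward_comm_seconds(tokens=8192, d_model=4096, n_layers=32, tp=8, bw=NVLINK_BW)
slow = tp_forward_comm_seconds(tokens=8192, d_model=4096, n_layers=32, tp=8, bw=IB_BW)
print(f"NVLink: {fast * 1e3:.2f} ms, InfiniBand: {slow * 1e3:.2f} ms, ratio {slow / fast:.0f}x")
```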
Q. What's the bubble in pipeline parallelism, and how do you reduce it?
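Naive PP: while one stage works on a batch the other stages sit idle (stage 0 waits while stages 1..P-1 run, and vice versa). Bubble time ≈ (P-1)/M of the ideal compute time, where P = pipeline depth and M = number of micro-batches.
Fix: more micro-batches (M >> P); 1F1B scheduling (same bubble as a GPipe-style schedule but lower activation memory, which makes a larger M affordable); interleaved 1F1B (Megatron) splits each stage into chunks for finer interleaving, shrinking the bubble by the number of chunks per device.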
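The bubble formula is easy to tabulate. In this sketch `bubble_ratio` is bubble time divided by ideal compute time, (P-1)/(v·M), where `v` is the number of interleaved chunks per device (v=1 for plain 1F1B or GPipe-style schedules); `idle_fraction` converts that to the share of total step time spent idle. Function names are made up for illustration.

```python
def bubble_ratio(p, m, v=1):
    # Bubble time relative to ideal compute time.
    return (p - 1) / (v * m)

def idle_fraction(p, m, v=1):
    # Share of the whole step (compute + bubble) that is bubble.
    r = bubble_ratio(p, m, v)
    return r / (1 + r)

for p, m, v in [(8, 8, 1), (8, 64, 1), (8, 64, 2)]:
    print(f"P={p} M={m} v={v}: ratio={bubble_ratio(p, m, v):.3f}, "
          f"idle={idle_fraction(p, m, v):.1%}")
```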
C. Numerical Precision
Q. Why does pretraining use BF16 compute with FP32 reductions?
- BF16 has the same exponent range as FP32 → no need for loss scaling (unlike FP16).
- But BF16 has only 7 mantissa bits (vs 23 in FP32) → accumulating many small grads loses precision.
- Solution: do the `all_reduce` and optimizer-state updates in FP32; activations and gradients in BF16.
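The precision loss can be shown without a GPU by emulating BF16 rounding in pure Python. This sketch rounds a float32 bit pattern to its top 16 bits with round-to-nearest-even; it handles only finite values and is meant as an illustration of the effect, not a model of any particular hardware accumulator.

```python
import struct

def round_to_fp32(x):
    return struct.unpack(">f", struct.pack(">f", x))[0]

def round_to_bf16(x):
    # BF16 = top 16 bits of the float32 pattern (finite inputs only).
    bits = struct.unpack(">I", struct.pack(">f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)  # round to nearest, ties to even
    bits &= 0xFFFF0000
    return struct.unpack(">f", struct.pack(">I", bits))[0]

def accumulate(increment, n, round_fn):
    # Running sum where every partial result is rounded to the target format.
    total = 0.0
    for _ in range(n):
        total = round_fn(total + increment)
    return total

grad = round_to_bf16(1e-3)  # a small gradient value, stored in BF16
n = 10_000
bf16_total = accumulate(grad, n, round_to_bf16)
fp32_total = accumulate(grad, n, round_to_fp32)
exact = grad * n

print(f"exact={exact:.4f}  fp32 accumulator={fp32_total:.4f}  bf16 accumulator={bf16_total:.4f}")
```

The BF16 accumulator stalls once the increment is smaller than half the spacing between adjacent BF16 values at the current sum, so it ends up far below the true total, while the FP32 accumulator stays close to it.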
Q. Where does FP8 break?
- Layers with high dynamic range (LM head logits, sometimes embeddings): scale them carefully (finer-grained scaling factors) or keep them in BF16.
- Outliers in activations (post-LayerNorm spikes): use per-tensor delayed scaling (NVIDIA Transformer Engine on Hopper).
- Low-rank adapters: LoRA matrices often need BF16 to converge.
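Delayed scaling picks the FP8 scale from the absolute maxima (amax) seen in previous steps, which is why a sudden activation spike hurts: the scale was chosen before the spike arrived. The sketch below is a simplified model of that idea, not the Transformer Engine API; the only hardware fact it relies on is that the largest finite E4M3 value is 448.

```python
E4M3_MAX = 448.0  # largest finite value in the FP8 E4M3 format

def delayed_scale(amax_history):
    # Scale chosen from past steps so the largest seen value maps near E4M3_MAX.
    return E4M3_MAX / max(amax_history)

def saturates(current_amax, scale):
    # True if the scaled tensor would exceed the FP8 range and be clipped.
    return current_amax * scale > E4M3_MAX

history = [3.0, 2.5, 3.2]  # ASSUMPTION: made-up amax values from recent steps
scale = delayed_scale(history)

print("typical step clips:", saturates(3.0, scale))
print("outlier step clips:", saturates(50.0, scale))
```

Keeping a longer amax history or a safety margin makes the scale more conservative, at the cost of wasting FP8 range on ordinary steps.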
D. Profiling Workflow
- PyTorch Profiler / Nsight Systems: see what fraction of step time is comm vs compute vs data load.
- Idle bubble check: GPU util dipping between steps = data loader is too slow. Increase workers, prefetch, pin memory.
- NCCL tracing: bad allreduce → check ring vs tree topology, MTU, GPUDirect RDMA.
- Memory profiling: `torch.cuda.memory_summary()` between steps; look for fragmentation, leaks (often from caching one-off tensors in eval).
- Per-op timing: identify the top 3 ops by time; optimize or fuse.
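For the per-op timing step, a minimal sketch with `torch.profiler` that returns the top ops by total CPU time. It assumes PyTorch is installed and runs on CPU with a toy model so it works anywhere; on a GPU you would also add `ProfilerActivity.CUDA` and sort by device time instead. The helper name `top_ops_by_cpu_time` is made up for illustration.

```python
import torch
from torch.profiler import profile, ProfilerActivity

def top_ops_by_cpu_time(fn, k=3):
    # Profile one call of fn and return the k ops with the largest total CPU time.
    with profile(activities=[ProfilerActivity.CPU]) as prof:
        fn()
    events = sorted(prof.key_averages(), key=lambda e: e.cpu_time_total, reverse=True)
    return [(e.key, e.cpu_time_total) for e in events[:k]]

# Toy model standing in for a real training or inference step.
model = torch.nn.Sequential(
    torch.nn.Linear(256, 256),
    torch.nn.ReLU(),
    torch.nn.Linear(256, 10),
)
x = torch.randn(64, 256)

top = top_ops_by_cpu_time(lambda: model(x))
for name, microseconds in top:
    print(f"{name}: {microseconds:.0f} us")
```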
E. Common Bugs
- NaN losses early in training: usually grad explosion in attention (no QK norm) or bad init. Add grad clipping, lower LR, check for fp16 overflow.
- Loss spikes during stable training: data shard with garbage; NaN in a single example; outlier batch with very long sequences.
- OOM only sometimes: variable sequence length pushing peak; bucket by length or set max_seq_len.
- Slow first iteration: kernel autotune (cudnn benchmark mode); compile cache cold. Warm up.
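- Throughput dropping over time: memory fragmentation in the caching allocator; `torch.cuda.empty_cache()` releases unused cached blocks and can help as a one-off, but not as a routine (it forces slow re-allocation).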
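For the intermittent-OOM case, bucketing by length can be sketched as follows: sort by length, then pack each batch under a padded-token budget so peak memory is bounded regardless of which sequences land together. The function name and the budget are hypothetical; a sequence longer than the budget still gets its own batch, which is where a hard `max_seq_len` comes in.

```python
def bucket_by_length(lengths, max_tokens_per_batch):
    # Returns batches of indices; padded cost = longest sequence * batch size.
    order = sorted(range(len(lengths)), key=lambda i: lengths[i])
    batches, current, current_max = [], [], 0
    for i in order:
        new_max = max(current_max, lengths[i])
        if current and new_max * (len(current) + 1) > max_tokens_per_batch:
            batches.append(current)  # budget exceeded: close the batch
            current, new_max = [], lengths[i]
        current.append(i)
        current_max = new_max
    if current:
        batches.append(current)
    return batches

lengths = [5, 100, 7, 90, 6, 8]  # ASSUMPTION: example sequence lengths
batches = bucket_by_length(lengths, max_tokens_per_batch=200)
for batch in batches:
    longest = max(lengths[i] for i in batch)
    print(batch, "padded tokens:", longest * len(batch))
```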
F. Performance Wins to Reach For
- Use `torch.compile` (PyTorch 2.x): often 1.3-2× free.
- FlashAttention-2/3 if available.
- Fused optim (`torch.optim.AdamW(fused=True)`).
- `bf16` instead of fp32.
- Gradient checkpointing only when memory-constrained (it costs ~30% throughput).
- Larger batch vs grad accum tradeoff: a bigger per-step batch is faster only if it fits in memory; otherwise reach the same effective batch with gradient accumulation.
- Avoid host↔device sync points (`.item()`, `.cpu()`, prints) inside the hot loop.
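A common instance of the sync-point problem is logging the loss with `.item()` on every step. One way around it, sketched here with a hypothetical helper and assuming PyTorch is installed, is to accumulate on the device and read the value back once:

```python
import torch

def mean_loss_single_sync(step_losses):
    # Accumulate on the same device as the losses; no host read inside the loop.
    running = torch.zeros((), device=step_losses[0].device)
    for loss in step_losses:
        running += loss.detach()
    # One .item() call = one host<->device sync for the whole window.
    return (running / len(step_losses)).item()

# Stand-in scalar losses; in training these would come from the model each step.
losses = [torch.tensor(1.0), torch.tensor(3.0), torch.tensor(2.0)]
print(mean_loss_single_sync(losses))
```

The same pattern applies to metrics and gradient norms: keep them as tensors inside the loop and convert to Python numbers only at logging intervals.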