Lab 02 — nanoGPT on TinyStories (Solution Walkthrough)

Phase: 5 — Training Small LLMs | Difficulty: ⭐⭐⭐⭐☆ | Time: 4–8 hours (incl. training)

Reuses the model from ../../phase-04-attention-transformers/lab-04-mini-transformer/solution.py. Concept primer: ../HITCHHIKERS-GUIDE.md §Pretraining mechanics.

Run

pip install -r requirements.txt
python solution.py --prepare        # tokenizes → ./data/train.bin, val.bin
python solution.py --train --steps 2000
python solution.py --sample --prompt "Once upon a time"

0. The mission

Go from raw text to a generating model in one script. End-to-end:

  1. --prepare: download TinyStories, tokenize with tiktoken's GPT-2 BPE, and write packed uint16 train.bin / val.bin files.
  2. --train: mixed-precision training (BF16 on Ampere or newer, FP16 + GradScaler on older GPUs such as the T4) with gradient accumulation, cosine LR, AdamW, and periodic eval + checkpoints.
  3. --sample: load a checkpoint and generate text from a prompt.

Default config trains a ~10M-param model in ~30 minutes on a T4 (Colab free) and produces grammatical English. Scale up to d=512, 8 layers and you have a real (if tiny) language model.


1. --prepare — the data pipeline

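import numpy as np
import tiktoken
from pathlib import Path
enc = tiktoken.get_encoding("gpt2")
out_dir = Path("data")
out_dir.mkdir(exist_ok=True)
ids = []
for text in stories:                        # `stories` (hypothetical name): iterable of story strings
    ids.extend(enc.encode_ordinary(text))   # ~100M tokens total for TinyStories
    ids.append(enc.eot_token)               # "<|endoftext|>" id 50256 after every doc
arr = np.array(ids, dtype=np.uint16)        # 50257 < 65536 → fits in uint16
arr.tofile(out_dir / "train.bin")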
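  • encode_ordinary ignores special tokens: a literal "<|endoftext|>" string inside a story is encoded as ordinary BPE pieces, so the real EOT id only appears where we insert it.
  • uint16 halves the disk footprint vs int32 (and quarters it vs NumPy's default int64). It is possible because the GPT-2 vocab (50257) is below 65536.
  • An EOT after every doc teaches the model where stories end. Training crops are taken at random offsets and routinely span story boundaries, so the EOT token is the only boundary signal.
  • We write train.bin and val.bin (90/10 split). Loading uses np.memmap(...), which pages data in from disk on demand, so a file of hundreds of MB (100M tokens × 2 bytes ≈ 200 MB) is never read into RAM wholesale.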
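The packing and split steps can be sketched without a tokenizer dependency. This sketch assumes the documents are already tokenized into lists of ids; `pack_docs` and `write_splits` are hypothetical helper names, and the 90/10 split is taken on the packed stream (splitting by document is an equally valid choice).

```python
import tempfile
from pathlib import Path

import numpy as np

EOT = 50256  # GPT-2 "<|endoftext|>" id


def pack_docs(docs_ids, eot=EOT):
    """Concatenate tokenized docs into one uint16 stream with an EOT after each doc."""
    flat = []
    for ids in docs_ids:
        flat.extend(ids)
        flat.append(eot)  # boundary marker after every document
    if flat and max(flat) >= 2**16:
        raise ValueError("token id does not fit in uint16")
    return np.array(flat, dtype=np.uint16)


def write_splits(arr, out_dir, val_frac=0.1):
    """Write train.bin / val.bin as raw uint16 bytes (hypothetical 90/10 split)."""
    n_train = int(len(arr) * (1 - val_frac))
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    arr[:n_train].tofile(out_dir / "train.bin")
    arr[n_train:].tofile(out_dir / "val.bin")
    return n_train


docs = [[10, 11, 12], [20, 21], [30, 31, 32, 33], [40]]  # made-up token ids
packed = pack_docs(docs)
with tempfile.TemporaryDirectory() as d:
    n_train = write_splits(packed, d)
    train = np.memmap(Path(d) / "train.bin", dtype=np.uint16, mode="r")
    head = train[:4].tolist()
    del train  # release the mapping before the directory is cleaned up
print(len(packed), n_train, head)  # 14 12 [10, 11, 12, 50256]
```

The .bin files carry no header, so the reader must know the dtype (uint16) in advance; that is why get_batch passes dtype=np.uint16 to np.memmap.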
import numpy as np
import torch
from pathlib import Path
out_dir = Path("data")
device = "cuda" if torch.cuda.is_available() else "cpu"

def get_batch(split, block_size, batch_size):
    data = np.memmap(out_dir / f"{split}.bin", dtype=np.uint16, mode="r")  # lazy, read-only
    ix = np.random.randint(0, len(data) - block_size - 1, (batch_size,))
    x = np.stack([data[i:i+block_size].astype(np.int64) for i in ix])
    y = np.stack([data[i+1:i+1+block_size].astype(np.int64) for i in ix])  # targets: inputs shifted by one
    return torch.from_numpy(x).to(device), torch.from_numpy(y).to(device)

Random-offset slicing is the standard trick: every batch is a fresh random crop, with no shuffling overhead. The model sees about steps × batch_size × block_size tokens in total (times grad_accum_steps if you accumulate); for 2000 steps × batch 64 × block 256 ≈ 33M tokens, about 1/3 of an epoch over TinyStories.
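A numpy-only sketch of the random-offset sampler and the token-budget arithmetic. `random_crops` and `tokens_seen` are hypothetical names; the sketch uses a seeded `np.random.default_rng` instead of `np.random.randint` so the result is reproducible, and `np.arange` stands in for a memmapped train.bin.

```python
import numpy as np


def random_crops(data, block_size, batch_size, rng):
    """Sample (x, y) crops at random offsets; y is x shifted left by one token."""
    # high is exclusive, so every y slice stays inside data
    ix = rng.integers(0, len(data) - block_size - 1, size=batch_size)
    x = np.stack([data[i:i + block_size] for i in ix]).astype(np.int64)
    y = np.stack([data[i + 1:i + 1 + block_size] for i in ix]).astype(np.int64)
    return x, y


def tokens_seen(steps, batch_size, block_size, grad_accum_steps=1):
    """Total training tokens consumed (sampling is with replacement)."""
    return steps * batch_size * block_size * grad_accum_steps


rng = np.random.default_rng(0)
data = np.arange(1000, dtype=np.uint16)  # stand-in for a memmapped train.bin
x, y = random_crops(data, block_size=8, batch_size=4, rng=rng)
print(x.shape, bool((y == x + 1).all()))  # (4, 8) True
print(tokens_seen(2000, 64, 256))         # 32768000
```

Because offsets are drawn with replacement, "1/3 of an epoch" means 1/3 of the dataset's size in tokens, not that a specific third of the stories was visited.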


2. --train — the training loop

2.1 Optimizer setup

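import torch

def configure_optimizer(model, lr, weight_decay):
    decay, no_decay = [], []
    for _, p in model.named_parameters():
        if not p.requires_grad:
            continue                   # skip frozen params
        if p.dim() >= 2:
            decay.append(p)            # weight matrices, embeddings
        else:
            no_decay.append(p)         # biases, LayerNorm gain/bias
    groups = [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]
    # Use the fused kernel only when params live on CUDA (CPU support varies by PyTorch version)
    use_fused = next(model.parameters()).is_cuda
    return torch.optim.AdamW(groups, lr=lr, betas=(0.9, 0.95), fused=use_fused)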

Three non-obvious choices:

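  1. No weight decay on 1D parameters. Decaying LayerNorm gains pulls them toward 0, distorting the normalization; biases are few and carry little overfitting risk, so decaying them buys nothing. This split is standard in GPT-style training code (BERT's optimizer already excluded biases and LayerNorm weights).
  2. betas=(0.9, 0.95): the GPT-3 and Llama choice (PyTorch's default is (0.9, 0.999)). The lower β₂ makes the second-moment estimate track recent gradients more closely, which matters when the LR is high and gradient statistics change quickly.
  3. fused=True: PyTorch 2.x's fused AdamW kernel, which performs the whole update in fewer kernel launches and is usually noticeably faster on GPU. It was CUDA-only in early 2.x releases, so gate it on the device.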
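To see the grouping rule on something small, here is a sketch with a hypothetical toy module (`Toy` is not the Phase 4 model) containing one embedding, one LayerNorm, and one Linear layer. As an assumption for portability, the fused flag is gated on where the parameters live.

```python
import torch
import torch.nn as nn


class Toy(nn.Module):
    """Hypothetical stand-in: embedding, LayerNorm, Linear."""

    def __init__(self, vocab=100, d=16):
        super().__init__()
        self.emb = nn.Embedding(vocab, d)  # weight 2D -> decay
        self.ln = nn.LayerNorm(d)          # weight, bias 1D -> no decay
        self.fc = nn.Linear(d, d)          # weight 2D -> decay, bias 1D -> no decay


model = Toy()
decay = [p for p in model.parameters() if p.requires_grad and p.dim() >= 2]
no_decay = [p for p in model.parameters() if p.requires_grad and p.dim() < 2]
opt = torch.optim.AdamW(
    [{"params": decay, "weight_decay": 0.1},
     {"params": no_decay, "weight_decay": 0.0}],
    lr=6e-4,
    betas=(0.9, 0.95),
    fused=next(model.parameters()).is_cuda,  # False here: the toy model is on CPU
)
n_decay = sum(p.numel() for p in decay)
n_no_decay = sum(p.numel() for p in no_decay)
print(n_decay, n_no_decay)  # 1856 48
```

Almost all parameters land in the decayed group; the 1D group is tiny, which is why exempting it costs nothing in regularization.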

2.2 Cosine LR schedule with warmup

import math

def get_lr(step, warmup, max_steps, lr_max, lr_min):
    if step < warmup:
        return lr_max * step / warmup                      # linear warmup (0 at step 0)
    if step > max_steps:
        return lr_min                                      # hold at the floor after the schedule
    decay_ratio = (step - warmup) / (max_steps - warmup)   # 0 → 1 over the decay phase
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # 1 → 0
    return lr_min + coeff * (lr_max - lr_min)
  • Warmup: 100–2000 steps. Early on, Adam's second-moment estimate rests on only a handful of gradients, so full-LR updates from a random init can be very large and blow up activations. Skipping warmup is one of the most common causes of NaN losses.
  • Cosine decay to lr_min = 0.1 * lr_max, the same shape GPT-3 used (decay to 10% of peak). It is a widely used, robust default.
  • LR is set per-step via for g in opt.param_groups: g["lr"] = lr.
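A sketch of the per-step wiring, with assumed values lr_max = 6e-4, lr_min = 0.1 * lr_max, 100 warmup steps and 2000 total steps. A throwaway `torch.nn.Linear` provides the optimizer; the forward/backward pass is omitted so only the schedule is exercised.

```python
import math

import torch


def get_lr(step, warmup, max_steps, lr_max, lr_min):  # same schedule as get_lr above
    if step < warmup:
        return lr_max * step / warmup
    if step > max_steps:
        return lr_min
    decay_ratio = (step - warmup) / (max_steps - warmup)
    return lr_min + 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) * (lr_max - lr_min)


lr_max, lr_min, warmup, max_steps = 6e-4, 6e-5, 100, 2000  # assumed values
opt = torch.optim.AdamW(torch.nn.Linear(2, 2).parameters(), lr=lr_max)
history = {}
for step in range(max_steps + 1):
    lr = get_lr(step, warmup, max_steps, lr_max, lr_min)
    for g in opt.param_groups:  # set the LR by hand every step
        g["lr"] = lr
    history[step] = opt.param_groups[0]["lr"]
    # ... forward / backward / opt.step() would go here ...
print({s: f"{history[s]:.2e}" for s in (0, 50, 100, 1050, 2000)})
```

With these values warmup reaches lr_max at step 100, the midpoint of the decay (step 1050) sits at (lr_max + lr_min) / 2 = 3.3e-4, and the schedule ends at lr_min = 6e-5.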

2.3 Mixed precision + gradient accumulation

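import torch

# Native BF16 needs Ampere (compute capability 8.x) or newer; a T4 (7.5) falls back to FP16
dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
# torch.amp.GradScaler in recent PyTorch; older releases spell it torch.cuda.amp.GradScaler
scaler = torch.amp.GradScaler("cuda", enabled=(dtype == torch.float16))  # no-op for BF16
ctx = torch.amp.autocast("cuda", dtype=dtype)

for micro in range(grad_accum_steps):
    x, y = get_batch("train", block_size, batch_size)
    with ctx:
        _, loss = model(x, y)
        loss = loss / grad_accum_steps     # average over micro-batches
    scaler.scale(loss).backward()          # grads accumulate across micro-batches
scaler.unscale_(opt)                       # true-scale grads before clipping
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
scaler.step(opt)                           # skipped if FP16 grads contain inf/NaN
scaler.update()
opt.zero_grad(set_to_none=True)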
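  • BF16 is preferred over FP16 when the GPU supports it (Ampere and newer): it has the same exponent range as FP32, so no GradScaler is needed (enabled=False). The T4 in the default config is a Turing card without native BF16, so on a T4 use dtype=torch.float16 with the GradScaler enabled.
  • Grad accumulation simulates a larger batch: with grad_accum_steps=8, the effective batch is batch_size × 8 × world_size (world_size = number of GPUs, 1 for this lab). Each micro-batch loss is divided by grad_accum_steps so the accumulated gradient equals the mean over the full effective batch.
  • clip_grad_norm_(..., 1.0) caps the global gradient norm at 1.0, so an occasional spike cannot corrupt AdamW's running moment estimates. With FP16, scaler.unscale_(opt) must come first so the threshold applies to the true gradients.
  • zero_grad(set_to_none=True) frees the .grad tensors instead of filling them with zeros, saving a memset per parameter. It has been the default since PyTorch 2.0, so passing it explicitly just documents intent.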
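The claim that dividing by grad_accum_steps reproduces the full-batch gradient can be checked directly. This CPU sketch uses a tiny float64 `torch.nn.Linear` with mean-reduced MSE (standing in for the mean-reduced LM cross-entropy) and assumes equal-sized micro-batches, which is what makes the equivalence exact.

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(4, 1).double()
x = torch.randn(8, 4, dtype=torch.float64)
y = torch.randn(8, 1, dtype=torch.float64)
loss_fn = torch.nn.MSELoss()  # mean reduction

# Reference: one full batch of 8
model.zero_grad(set_to_none=True)
loss_fn(model(x), y).backward()
full_grad = model.weight.grad.clone()

# Accumulated: 4 equal micro-batches of 2, each loss divided by grad_accum_steps
grad_accum_steps = 4
model.zero_grad(set_to_none=True)
for xb, yb in zip(x.chunk(grad_accum_steps), y.chunk(grad_accum_steps)):
    (loss_fn(model(xb), yb) / grad_accum_steps).backward()  # .grad accumulates
accum_grad = model.weight.grad.clone()

print(torch.allclose(full_grad, accum_grad))  # True
```

Without the division the accumulated gradient would be grad_accum_steps times too large, which silently multiplies the effective LR.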

2.4 Periodic eval + checkpoint

# best_val = float("inf") must be set once, before the training loop
if step % eval_interval == 0:
    model.eval()                               # disable dropout during eval
    with torch.no_grad():
        losses = []
        for _ in range(eval_iters):
            xv, yv = get_batch("val", block_size, batch_size)
            with ctx:
                _, loss = model(xv, yv)
            losses.append(loss.item())
        val_loss = sum(losses) / len(losses)   # mean over eval_iters random batches
    model.train()
    if val_loss < best_val:
        best_val = val_loss
        torch.save({"model": model.state_dict(), "opt": opt.state_dict(),
                    "step": step, "val_loss": val_loss}, ckpt_path)

Saving optimizer state allows resume. Saving only the best-val checkpoint avoids disk bloat. For real runs, also save a last checkpoint every N steps for crash recovery.
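A CPU-only sketch of the save/resume round trip, using a tiny `torch.nn.Linear` and an in-memory `io.BytesIO` buffer in place of `ckpt_path`; the checkpoint keys mirror the dict above. The `make` helper is hypothetical. In a real run you would likely also store the model config so --sample can rebuild the architecture (an assumption about how solution.py is organized).

```python
import io

import torch


def make():
    """Hypothetical helper: identically initialized model + optimizer."""
    torch.manual_seed(0)
    model = torch.nn.Linear(4, 2)
    opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
    return model, opt


# Train a few steps, then checkpoint
model, opt = make()
for _ in range(3):
    loss = model(torch.randn(5, 4)).pow(2).mean()
    loss.backward()
    opt.step()
    opt.zero_grad(set_to_none=True)
buf = io.BytesIO()  # stands in for ckpt_path
torch.save({"model": model.state_dict(), "opt": opt.state_dict(),
            "step": 3, "val_loss": float(loss)}, buf)

# Resume into fresh objects
buf.seek(0)
ckpt = torch.load(buf, map_location="cpu")
model2, opt2 = make()
model2.load_state_dict(ckpt["model"])
opt2.load_state_dict(ckpt["opt"])  # restores AdamW's exp_avg / exp_avg_sq / step
start_step = ckpt["step"] + 1
```

Restoring only the model weights would restart AdamW's moment estimates from zero, which behaves like a second, unwarmed start.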


3. --sample — generation

Load checkpoint, tokenize prompt with tiktoken, call model.generate(...) from Phase 4. Use top_k=200, temperature=0.8 for stories (slightly conservative).
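Phase 4's model.generate(...) is not reproduced here. As a sketch only, the hypothetical `sample_next` below shows the temperature + top-k step such a generate loop typically applies to the last position's logits.

```python
import torch


def sample_next(logits, temperature=0.8, top_k=200, generator=None):
    """Sample one token id per row from (batch, vocab) logits."""
    logits = logits / temperature  # < 1 sharpens, > 1 flattens the distribution
    if top_k is not None:
        k = min(top_k, logits.size(-1))
        kth = torch.topk(logits, k, dim=-1).values[:, -1:]  # k-th largest per row
        # keep the top-k logits (ties at the k-th value are also kept)
        logits = logits.masked_fill(logits < kth, float("-inf"))
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1, generator=generator)


g = torch.Generator().manual_seed(0)
logits = torch.tensor([[2.0, 1.0, 0.5, -1.0, -3.0]])
print(sample_next(logits, temperature=1.0, top_k=1, generator=g))  # tensor([[0]])
samples = torch.cat([sample_next(logits, top_k=2, generator=g) for _ in range(200)])
print(set(samples.flatten().tolist()))  # only ids 0 and 1 can appear
```

top_k=1 reduces to greedy decoding; with a 50257-token vocab, top_k=200 mainly cuts off the long tail of implausible tokens while temperature 0.8 mildly sharpens what remains.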


4. Expected output

Default config (d=128, 6 layers, 8 heads, block=256), 2000 steps, T4 GPU:

step    0  loss=10.4321  lr=0.000e+00  ms/step=  N/A
step  100  loss= 5.1234  lr=2.99e-04  ms/step= 280
step  500  loss= 3.4567  lr=5.95e-04  ms/step= 282
step 1000  loss= 2.8902  lr=4.50e-04  ms/step= 281
step 2000  loss= 2.4521  lr=6.00e-05  ms/step= 280
val_loss=2.41 (best, saved)

[sample] Once upon a time, there was a little girl named Lily.
She loved to play with her toys. One day, she found a big box.

Sanity numbers:

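  • Initial loss ≈ ln(50257) ≈ 10.82, the cross-entropy of a uniform guess over the vocab. ✅
  • Final val loss for ~10M params on TinyStories: ~2.3–2.5. Larger models trained on more tokens go lower (see the Chinchilla exercise below).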
  • 280 ms/step on T4 is normal; 90 ms/step on a 4090.
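The initial-loss sanity number is just the cross-entropy of a uniform prediction. This sketch confirms it with all-zero logits (exactly uniform after softmax); a freshly initialized model with small logits should start close to this value.

```python
import math

import torch
import torch.nn.functional as F

V = 50257  # GPT-2 vocab size
print(math.log(V))  # ≈ 10.8249

logits = torch.zeros(4, V)  # uniform predictions
targets = torch.randint(0, V, (4,))
ce = F.cross_entropy(logits, targets).item()
print(ce)  # ≈ ln(V), whatever the targets are
```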

5. Diagnosing training pathologies

Symptom | Likely cause
--- | ---
Loss = NaN at step ~10 | No warmup, or LR too high. Drop LR 10× or add warmup.
Loss flat at ≈ ln(V) for hundreds of steps | LR way too low, or a model bug (no gradient flow).
Loss decreases, then explodes at step ~1000 | Forgot grad clipping, or bad init scale.
Train loss ≪ val loss after few steps | Overfitting; reduce model size or add dropout.
Train loss ≈ val loss, but both high | Underfitting; increase model size or steps.
Train loss decreases but val plateaus high | Data quality issue or distribution mismatch.
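Some rows of the table can be checked mechanically from a logged loss list. The hypothetical `diagnose` helper below is a sketch: its thresholds (flat_after, flat_tol, spike_factor) are illustrative guesses, not tuned values, and it only covers the NaN, stuck-at-ln(V), and explode-after-progress rows.

```python
import math


def diagnose(losses, vocab_size, flat_after=200, flat_tol=0.1, spike_factor=2.0):
    """Flag a few loss-curve pathologies (illustrative thresholds)."""
    issues = []
    if any(not math.isfinite(loss) for loss in losses):
        issues.append("non-finite loss: check warmup / LR / FP16 scaling")
    finite = [loss for loss in losses if math.isfinite(loss)]
    # Still within flat_tol of the uniform-guess loss after flat_after steps
    if len(finite) > flat_after and abs(finite[flat_after] - math.log(vocab_size)) < flat_tol:
        issues.append("stuck near ln(V): LR too low or no gradient flow")
    best = math.inf
    for loss in finite:
        if loss > spike_factor * best:  # jumped far above the best loss so far
            issues.append("spike after progress: check grad clipping / init scale")
            break
        best = min(best, loss)
    return issues


print(diagnose([10.82] * 300, vocab_size=50257))
# ['stuck near ln(V): LR too low or no gradient flow']
```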

6. Common pitfalls

  1. Running --prepare every time — cache the .bin files; tokenization is slow.
  2. Hard-coding "cuda" in autocast when running on CPU: pass device_type="cpu" there. CPU autocast with BF16 has existed since PyTorch 1.10.
  3. memmap on a remote/network file: random access is brutal on NFS. Copy to local SSD.
  4. torch.compile(model) can help but breaks eager debugging — enable last.
  5. Checkpoint with model.state_dict() only — lose optimizer state → can't resume cleanly.
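For pitfall 2, one option is to choose the autocast context from the device instead of hard-coding "cuda". The hypothetical `autocast_ctx` helper below is a sketch that assumes BF16 on CPU and on Ampere+ GPUs, FP16 on older GPUs such as the T4, and full precision on anything else.

```python
from contextlib import nullcontext

import torch


def autocast_ctx(device_type):
    """Pick an autocast context for the device (hypothetical helper)."""
    if device_type == "cuda":
        major, _ = torch.cuda.get_device_capability()
        dtype = torch.bfloat16 if major >= 8 else torch.float16  # Ampere+ has native BF16
        return torch.autocast(device_type="cuda", dtype=dtype)
    if device_type == "cpu":
        return torch.autocast(device_type="cpu", dtype=torch.bfloat16)
    return nullcontext()  # other backends: run in full precision


device_type = "cuda" if torch.cuda.is_available() else "cpu"
layer = torch.nn.Linear(8, 8).to(device_type)
with autocast_ctx(device_type):
    out = layer(torch.randn(2, 8, device=device_type))
print(device_type, out.dtype)  # linear runs in the autocast dtype
```

On the FP16 branch you still need the GradScaler from the training loop; the helper only picks the forward-pass precision.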

7. Stretch exercises

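  • Scale up to d=512, 8 layers, block=512: ~50M params (about half of them in the 50257 × 512 token embedding), ~4 hours on a single A100. Val loss should reach ~1.9.
  • Replace LayerNorm with RMSNorm: a little cheaper (no mean-centering, no bias), typically with no quality loss.
  • Add RoPE (rotary position embeddings): better long-context generalization.
  • Use a SwiGLU MLP: the GLU-variants paper (Shazeer, 2020) reports lower perplexity at matched parameter count, shrinking the hidden size to ~8d/3 to offset the third weight matrix; keeping 4d instead costs ~50% more MLP params.
  • Compute the Chinchilla compute-optimal token budget for your model: tokens ≈ 20 × params. For 10M params, train on 200M tokens, about 6× what the default 2000-step run sees.
  • Run on a FineWeb-Edu sample instead of TinyStories: broader, more realistic web text that is much harder to model, so expect a noticeably higher loss.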
  • Visualize attention at a checkpoint: pick a position, plot attention weights across all layers. Identify induction heads.
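Applying the 20-tokens-per-parameter rule needs a parameter count. This sketch assumes a GPT-2-style block (biases everywhere, learned position embeddings, 4d MLP, weight-tied LM head); the Phase 4 model may differ, and `gpt2_style_params` / `chinchilla_tokens` are hypothetical names.

```python
def gpt2_style_params(vocab, d, n_layer, block, tied=True):
    """Parameter count for an assumed GPT-2-style decoder.

    Per layer: QKV (d*3d + 3d) + attn proj (d*d + d)
               + MLP (d*4d + 4d, 4d*d + d) + two LayerNorms (2 * 2d)
             = 12 d^2 + 13 d
    """
    per_layer = 12 * d * d + 13 * d
    emb = vocab * d + block * d      # token + learned position embeddings
    final_ln = 2 * d
    head = 0 if tied else vocab * d  # a tied LM head reuses the token embedding
    return emb + n_layer * per_layer + final_ln + head


def chinchilla_tokens(params):
    return 20 * params  # Chinchilla rule of thumb


small = gpt2_style_params(50257, 128, 6, 256)  # default config
big = gpt2_style_params(50257, 512, 8, 512)    # stretch config
print(small, chinchilla_tokens(small))  # 7655552 153111040
print(big, chinchilla_tokens(big))      # 51213824 1024276480
```

Under these assumptions the default config has about 7.7M parameters (about 14.1M with an untied head), and the d=512 config about 51M, roughly half of it in the token embedding. At this scale the GPT-2 vocab dominates the count.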

8. What this lab proves about you

You can run a complete pretraining loop end-to-end, choose every hyperparameter with justification, debug loss-curve pathologies, and ship a generating model from raw text. That is roughly the practical bar for applied research engineering roles at frontier labs: the difference between someone who knows transformers and someone who can train them.