Lab 04 — Mini Transformer (Solution Walkthrough)

Phase: 4 — Attention & Transformers | Difficulty: ⭐⭐⭐⭐☆ | Time: 4–6 hours

Concept primer: ../HITCHHIKERS-GUIDE.md §Attention and §Transformer architecture. This is the most important lab in the curriculum — every later phase reuses or extends this code.

Run

pip install -r requirements.txt
python solution.py   # runs init-loss + single-batch overfit sanity checks

0. The mission

Build a decoder-only Transformer from scratch in ~200 lines that:

  • Implements scaled dot-product attention with causal masking.
  • Uses pre-norm with residual connections.
  • Passes the two universal sanity tests: init-loss matches the entropy of the uniform vocab distribution, and the model can overfit a single batch to ~zero loss in 200 steps.

This is the kernel that Phase 5 trains on TinyStories, Phase 6 fine-tunes via LoRA, and Phase 9 retrofits with a KV cache. Get this right and the rest of the curriculum compiles.


1. The math

For queries $Q$, keys $K$ and values $V$ covering all token positions at once:

$$ \mathrm{Attn}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_\text{head}}} + M\right) V $$

where $M$ is the causal mask: $M_{ij} = 0$ if $i \ge j$ else $-\infty$. Multi-head attention runs n_head of these in parallel on slices of dim d_head = d_model / n_head, then concatenates.
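
As a worked illustration of the formula above, here is a minimal single-head sketch in plain Python (no PyTorch), so every term is visible. The function name `causal_attention` and the toy `q`, `k`, `v` values are made up for this example; the lab's real implementation is the batched tensor version in section 3.

```python
import math

def causal_attention(q, k, v):
    """Single-head causal attention on plain Python lists.

    q, k, v are lists of T vectors; q and k vectors have length d_head.
    Returns (outputs, weights), where weights[i][j] is how much
    position i attends to position j.
    """
    T, d_head = len(q), len(q[0])
    outputs, weights = [], []
    for i in range(T):
        # Only keys j <= i are scored. Positions j > i are where M_ij = -inf,
        # which contributes exp(-inf) = 0 to the softmax.
        scores = [
            sum(a * b for a, b in zip(q[i], k[j])) / math.sqrt(d_head)
            for j in range(i + 1)
        ]
        top = max(scores)                      # subtract the max for stability
        exps = [math.exp(s - top) for s in scores]
        total = sum(exps)
        row = [e / total for e in exps] + [0.0] * (T - i - 1)
        weights.append(row)
        outputs.append(
            [sum(row[j] * v[j][c] for j in range(T)) for c in range(len(v[0]))]
        )
    return outputs, weights

# Toy inputs (illustrative values only): T = 3 positions, d_head = 2.
q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
v = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
outputs, weights = causal_attention(q, k, v)
for row in weights:
    print([round(w, 3) for w in row])
```

The first row of `weights` puts all its mass on position 0, because the first token has nothing earlier to attend to; every entry above the diagonal is exactly zero.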

The full block (pre-norm):

$$ \begin{aligned} x &\leftarrow x + \mathrm{Attn}(\mathrm{LN}(x)) \\ x &\leftarrow x + \mathrm{MLP}(\mathrm{LN}(x)) \end{aligned} $$

A model is n_layer blocks stacked plus token + position embeddings at the input and a linear head at the output.


2. GPTConfig

@dataclass
class GPTConfig:
    vocab_size: int = 50257
    n_layer: int = 6
    n_head: int = 8
    d_model: int = 512
    d_ff: int = 2048           # typically 4 * d_model
    block_size: int = 1024     # max context length
    dropout: float = 0.0
    tie_weights: bool = True
  • vocab_size = 50257 matches GPT-2 BPE.
  • d_ff = 4 * d_model is the ratio used in "Attention Is All You Need" (d_model = 512, d_ff = 2048) and is still the common default. It gives the MLP enough capacity to act as the model's "memory" (Geva et al. 2021 argue that feed-forward layers behave as key-value memories).
  • block_size is the maximum sequence the position-embedding table supports.
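  • tie_weights=True shares the vocab × d_model matrix between the input embedding and the output head. With the defaults above that removes 50257 × 512 ≈ 25.7M parameters (about 51 MB in fp16, 103 MB in fp32); at 7B scale the saving can reach roughly 1 GB, depending on vocabulary size. Quality is typically identical or slightly better.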
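
The arithmetic behind these bullets can be checked with a few lines of plain Python. This is a sketch only: the dataclass repeats the fields shown above, and `derived_sizes` is a hypothetical helper, not part of solution.py.

```python
from dataclasses import dataclass

@dataclass
class GPTConfig:
    # Same fields and defaults as the config shown above.
    vocab_size: int = 50257
    n_layer: int = 6
    n_head: int = 8
    d_model: int = 512
    d_ff: int = 2048
    block_size: int = 1024
    dropout: float = 0.0
    tie_weights: bool = True

def derived_sizes(cfg):
    """Hypothetical helper: quantities implied by a config."""
    assert cfg.d_model % cfg.n_head == 0, "d_model must divide evenly into heads"
    embed_params = cfg.vocab_size * cfg.d_model
    return {
        "d_head": cfg.d_model // cfg.n_head,
        "ff_ratio": cfg.d_ff / cfg.d_model,
        "embed_params": embed_params,
        # Bytes saved by sharing one matrix instead of storing two.
        "tied_saving_mb_fp32": embed_params * 4 / 1e6,
        "tied_saving_mb_fp16": embed_params * 2 / 1e6,
    }

sizes = derived_sizes(GPTConfig())
for name, value in sizes.items():
    print(f"{name}: {value}")
```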

3. CausalSelfAttention — the centerpiece

3.1 The fused QKV projection

self.qkv = nn.Linear(cfg.d_model, 3 * cfg.d_model, bias=False)

One large matmul is faster than three smaller ones (better GPU utilization). Mathematically identical to three separate linears. bias=False is the modern default — biases add parameters without measurable quality benefit at scale.
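
To see that the fused projection really is three linears stacked, the following sketch slices the fused weight into three matrices and compares the results. The small `d_model` and the variable names are chosen for illustration only.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model = 16
fused = nn.Linear(d_model, 3 * d_model, bias=False)

# nn.Linear stores weight as (out_features, in_features), so the three
# projections are stacked along dim 0 of the fused weight.
wq, wk, wv = fused.weight.split(d_model, dim=0)    # each (d_model, d_model)

x = torch.randn(2, 5, d_model)                     # (B, T, C)
q_fused, k_fused, v_fused = fused(x).split(d_model, dim=-1)

# The same projections computed as three separate matmuls.
q_sep, k_sep, v_sep = x @ wq.T, x @ wk.T, x @ wv.T

print(torch.allclose(q_fused, q_sep, atol=1e-5))
print(torch.allclose(k_fused, k_sep, atol=1e-5))
print(torch.allclose(v_fused, v_sep, atol=1e-5))
```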

3.2 The causal mask buffer

self.register_buffer(
    "mask",
    torch.tril(torch.ones(cfg.block_size, cfg.block_size, dtype=torch.bool))
         .view(1, 1, cfg.block_size, cfg.block_size),
    persistent=False,
)
  • torch.tril(...) gives a lower-triangular boolean matrix: True on and below the diagonal. Position i can attend to position j iff i ≥ j.
  • Shape (1, 1, T, T) so it broadcasts over batch and head dims.
  • register_buffer so the mask moves to GPU with .to(device). persistent=False keeps it out of state_dict (deterministically reconstructable).
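
A quick way to confirm these three properties is to register the same buffer on a throwaway module. `MaskHolder` below is a hypothetical stand-in for CausalSelfAttention that holds nothing but the mask.

```python
import torch
import torch.nn as nn

class MaskHolder(nn.Module):
    """Hypothetical module holding only the causal mask buffer."""
    def __init__(self, block_size):
        super().__init__()
        self.register_buffer(
            "mask",
            torch.tril(torch.ones(block_size, block_size, dtype=torch.bool))
                 .view(1, 1, block_size, block_size),
            persistent=False,
        )

holder = MaskHolder(block_size=8)

# Slice down to the current sequence length, as the forward pass does.
T = 3
sliced = holder.mask[:, :, :T, :T]
print(sliced.shape)          # broadcastable over batch and head dims
print(sliced[0, 0])          # True on and below the diagonal

# A non-persistent buffer is not saved with the weights.
print("mask" in holder.state_dict())
```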

3.3 The forward — six lines that contain the whole transformer

def forward(self, x):
    B, T, C = x.shape
    qkv = self.qkv(x)                                              # (B, T, 3C)
    q, k, v = qkv.split(C, dim=-1)
    q = q.view(B, T, self.n_head, self.d_head).transpose(1, 2)
    k = k.view(B, T, self.n_head, self.d_head).transpose(1, 2)
    v = v.view(B, T, self.n_head, self.d_head).transpose(1, 2)
    att = (q @ k.transpose(-2, -1)) / math.sqrt(self.d_head)
    att = att.masked_fill(~self.mask[:, :, :T, :T], float("-inf"))
    att = F.softmax(att, dim=-1)
    att = self.attn_drop(att)
    y = att @ v
    y = y.transpose(1, 2).contiguous().view(B, T, C)
    return self.resid_drop(self.proj(y))

Decoded:

  1. qkv.split(C, -1): split the fused projection into Q, K, V, each of shape (B, T, C).
  2. view + transpose(1, 2): reshape to (B, n_head, T, d_head). Moving the head dim next to the batch dim leaves (T, d_head) as the last two dims, so each head gets its own batched matmul.
  3. q @ k.transpose(-2, -1): batched matmul → attention scores (B, n_head, T, T).
  4. / math.sqrt(self.d_head): the most important divisor in deep learning. Without it, scores have variance ≈ d_head (for unit-variance q and k entries), which pushes softmax into saturation and makes gradients vanish.
  5. masked_fill(~mask, -inf): -inf rather than -1e9 because exp(-inf) is exactly 0 in every dtype, so masked positions get exactly zero weight. A large finite constant such as -1e9 is not representable in fp16 (max ≈ 65504).
  6. softmax(dim=-1): normalize across the key dimension. Each row sums to 1.
  7. att @ v → (B, n_head, T, d_head): weighted sum of values.
  8. transpose(1, 2).contiguous().view(B, T, C): undo the head split. contiguous() is required before view because transpose only changes strides.
  9. self.proj(y): output projection (per-block recombination of head info).
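
One way to convince yourself the masking steps are right is a causality check: perturb the last position and confirm that earlier outputs do not move. The sketch below re-implements only the score, mask, softmax and weighted-sum steps as a free function (an assumption made for brevity; it omits the learned projections and dropout of the real module).

```python
import math
import torch
import torch.nn.functional as F

def causal_attention(q, k, v):
    """q, k, v: (B, n_head, T, d_head). Returns (output, attention weights)."""
    T, d_head = q.shape[-2], q.shape[-1]
    mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=q.device))
    att = (q @ k.transpose(-2, -1)) / math.sqrt(d_head)   # (B, n_head, T, T)
    att = att.masked_fill(~mask, float("-inf"))
    att = F.softmax(att, dim=-1)
    return att @ v, att

torch.manual_seed(0)
B, H, T, D = 2, 4, 6, 8
q, k, v = (torch.randn(B, H, T, D) for _ in range(3))
y, att = causal_attention(q, k, v)

# Perturb only the LAST position's key and value.
k2, v2 = k.clone(), v.clone()
k2[:, :, -1] += 1.0
v2[:, :, -1] += 1.0
y2, _ = causal_attention(q, k2, v2)

earlier_unchanged = torch.allclose(y[:, :, :-1], y2[:, :, :-1])
last_changed = not torch.allclose(y[:, :, -1], y2[:, :, -1])
print(earlier_unchanged, last_changed)
```

If `earlier_unchanged` were ever False, information would be leaking from the future, which usually points at the mask slice or the mask direction.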

3.4 Why two dropouts?

attn_drop masks attention weights (random tokens become "ignored"); resid_drop masks the output before it is added to the residual stream. Both are set to 0 in this skeleton; turn them on when fine-tuning on small datasets.


4. MLP

class MLP(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.fc = nn.Linear(cfg.d_model, cfg.d_ff, bias=False)
        self.proj = nn.Linear(cfg.d_ff, cfg.d_model, bias=False)
        self.drop = nn.Dropout(cfg.dropout)

    def forward(self, x):
        return self.drop(self.proj(F.gelu(self.fc(x))))

GELU(x) = x * Φ(x), where Φ is the standard normal CDF. It behaves like a smooth ReLU and is empirically better than ReLU for transformers.

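Modern variants use SwiGLU (Llama, Qwen): (SiLU(W_g x)) * (W_u x), then W_d. That is three matrices instead of two, so at the same d_ff it adds 50% MLP params (implementations usually shrink d_ff to about 2/3 to compensate), for a perplexity improvement of roughly 2%.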
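
The GELU definition and the parameter counts above can be verified with stdlib Python. The helper names `gelu`, `mlp_params` and `swiglu_params` are illustrative, not taken from solution.py.

```python
import math

def gelu(x):
    """Exact GELU: x * Phi(x), with Phi the standard normal CDF."""
    return x * 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))

def mlp_params(d_model, d_ff):
    """Two bias-free matrices: fc (d_model x d_ff) and proj (d_ff x d_model)."""
    return 2 * d_model * d_ff

def swiglu_params(d_model, d_ff):
    """Three bias-free matrices: W_g, W_u (d_model x d_ff) and W_d (d_ff x d_model)."""
    return 3 * d_model * d_ff

for x in (-2.0, -1.0, 0.0, 1.0, 2.0):
    print(f"gelu({x:+.1f}) = {gelu(x):+.4f}   relu = {max(0.0, x):+.4f}")

d_model, d_ff = 512, 2048
ratio = swiglu_params(d_model, d_ff) / mlp_params(d_model, d_ff)
print(mlp_params(d_model, d_ff), swiglu_params(d_model, d_ff), ratio)
```

Unlike ReLU, GELU is slightly negative for small negative inputs and approaches the identity for large positive ones.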


5. Block — the pre-norm layout

class Block(nn.Module):
    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x

Pre-norm vs post-norm is one of the most consequential layout choices in the block:

  • Post-norm (original 2017 paper): x = LN(x + sublayer(x)). Trains poorly without warmup; gradients pass through LN on every residual.
  • Pre-norm (GPT-2 onwards): x = x + sublayer(LN(x)). Residual stream is "clean": gradients flow through the identity path of every layer untouched. Trains far more stably, usually with little or no warmup, even at large depth.
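
The "clean residual stream" point can be made concrete with a small experiment. In this sketch the wrapper classes `PreNormResidual` and `PostNormResidual` are hypothetical, and the sublayer is a zeroed linear layer standing in for attention or the MLP. With a sublayer that outputs zero, pre-norm is an exact identity, while post-norm still rewrites the stream through LayerNorm.

```python
import torch
import torch.nn as nn

class PreNormResidual(nn.Module):
    """x + sublayer(LN(x)): the identity path is never normalized."""
    def __init__(self, d_model, sublayer):
        super().__init__()
        self.ln = nn.LayerNorm(d_model)
        self.sublayer = sublayer

    def forward(self, x):
        return x + self.sublayer(self.ln(x))

class PostNormResidual(nn.Module):
    """LN(x + sublayer(x)): every layer renormalizes the residual stream."""
    def __init__(self, d_model, sublayer):
        super().__init__()
        self.ln = nn.LayerNorm(d_model)
        self.sublayer = sublayer

    def forward(self, x):
        return self.ln(x + self.sublayer(x))

def zero_sublayer(d_model):
    # Stand-in for attention/MLP that contributes nothing.
    layer = nn.Linear(d_model, d_model, bias=False)
    nn.init.zeros_(layer.weight)
    return layer

torch.manual_seed(0)
d_model = 16
x = torch.randn(2, 4, d_model) + 3.0      # deliberately not zero-mean

pre_out = PreNormResidual(d_model, zero_sublayer(d_model))(x)
post_out = PostNormResidual(d_model, zero_sublayer(d_model))(x)

print(torch.equal(pre_out, x))                   # identity path preserved
print(post_out.mean(dim=-1).abs().max().item())  # stream was re-centered
```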

Modern alternative: RMSNorm (Llama) drops mean-subtraction; it is roughly 10% faster with comparable quality.


6. MiniGPT

self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.d_model)
self.pos_emb = nn.Embedding(cfg.block_size, cfg.d_model)
self.blocks = nn.ModuleList([Block(cfg) for _ in range(cfg.n_layer)])
self.ln_f = nn.LayerNorm(cfg.d_model)
self.head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
if cfg.tie_weights:
    self.head.weight = self.tok_emb.weight
  • Learned absolute position embeddings (GPT-2 style). Modern models use RoPE (rotary, applied in attention itself) — handles longer contexts and extrapolates better.
  • Final LayerNorm before head (ln_f) — important for training stability.
  • Weight tying by direct assignment. Both tok_emb.weight and head.weight point to the same tensor → only one tensor in the optimizer.
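
The difference between tying by assignment and copying values shows up directly in the parameter count. Both classes below are hypothetical miniatures containing only the embedding and the head.

```python
import torch
import torch.nn as nn

class TiedEmbeddingHead(nn.Module):
    """Tied: head.weight and tok_emb.weight are the same Parameter object."""
    def __init__(self, vocab_size, d_model):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)
        self.head.weight = self.tok_emb.weight

class CopiedEmbeddingHead(nn.Module):
    """Not tied: values match at init, but there are two separate tensors."""
    def __init__(self, vocab_size, d_model):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)
        with torch.no_grad():
            self.head.weight.copy_(self.tok_emb.weight)

def count_params(module):
    # parameters() yields each shared Parameter only once.
    return sum(p.numel() for p in module.parameters())

tied = TiedEmbeddingHead(vocab_size=100, d_model=16)
copied = CopiedEmbeddingHead(vocab_size=100, d_model=16)

print(tied.head.weight is tied.tok_emb.weight, count_params(tied))
print(copied.head.weight is copied.tok_emb.weight, count_params(copied))
```

The shapes line up without a transpose because nn.Embedding stores (vocab_size, d_model) and nn.Linear stores (out_features, in_features), which is also (vocab_size, d_model) for the head.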

6.1 Init

nn.init.normal_(m.weight, mean=0.0, std=0.02)

std=0.02 is GPT-2's choice. Theoretically 0.02 / sqrt(2 * n_layer) is better for residual-path projections (keeps activation variance constant across layers), but 0.02 everywhere works fine for small models.
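
A minimal sketch of both init options follows. `init_weights` and `residual_proj_std` are hypothetical helpers written for this example, and which layers count as "residual-path projections" (typically the attention and MLP output projections) is left to the caller. The factor 2 * n_layer comes from each block adding two contributions to the residual stream.

```python
import math
import torch
import torch.nn as nn

def init_weights(module, std=0.02):
    """Hypothetical init hook: N(0, std) for Linear and Embedding weights."""
    if isinstance(module, (nn.Linear, nn.Embedding)):
        nn.init.normal_(module.weight, mean=0.0, std=std)
    if isinstance(module, nn.Linear) and module.bias is not None:
        nn.init.zeros_(module.bias)

def residual_proj_std(n_layer, base_std=0.02):
    """Scaled std for projections that write into the residual stream."""
    return base_std / math.sqrt(2 * n_layer)

torch.manual_seed(0)
layer = nn.Linear(512, 512, bias=False)
layer.apply(init_weights)                 # apply() also visits the module itself
measured_std = layer.weight.std().item()

scaled_std = residual_proj_std(n_layer=6)
print(f"measured std = {measured_std:.4f}")
print(f"scaled std for n_layer=6 = {scaled_std:.6f}")
```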

6.2 generate

for _ in range(max_new_tokens):
    ctx = idx[:, -self.cfg.block_size:]
    logits, _ = self(ctx)
    logits = logits[:, -1, :] / max(1e-6, temperature)
    if top_k is not None:
        v, _ = torch.topk(logits, top_k)
        logits[logits < v[:, [-1]]] = float("-inf")
    probs = F.softmax(logits, dim=-1)
    next_id = torch.multinomial(probs, 1)
    idx = torch.cat([idx, next_id], dim=1)
  • ctx = idx[:, -block_size:] truncates context to the model's max — naive but correct. The KV-cache lab in Phase 9 makes this efficient.
  • This is O(T²) per generated token because we re-process the entire context. Phase 9 fixes this with KV cache → O(T).
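
To put numbers on the O(T²) versus O(T) claim, the sketch below counts query-key score entries only (it ignores the MLP and projections). Both function names are illustrative, and `total_pairs` assumes the prompt plus the new tokens fits inside block_size, so no truncation happens.

```python
def attention_pairs(context_len, kv_cache):
    """Query-key scores computed to produce ONE new token."""
    if kv_cache:
        # Only the newest query is scored against all cached keys.
        return context_len
    # Every position is re-scored against itself and its past.
    return context_len * (context_len + 1) // 2

def total_pairs(prompt_len, new_tokens, kv_cache):
    """Scores computed over a whole generation (assumes no truncation)."""
    total = 0
    for step in range(new_tokens):
        context_len = prompt_len + step
        # With a cache, the first step is still a full pass over the prompt.
        use_cache = kv_cache and step > 0
        total += attention_pairs(context_len, use_cache)
    return total

print(attention_pairs(1024, kv_cache=False))   # full recompute at T = 1024
print(attention_pairs(1024, kv_cache=True))    # one query row at T = 1024
print(total_pairs(16, 4, kv_cache=False), total_pairs(16, 4, kv_cache=True))
```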

7. The two sanity tests

These should be the first things you run on any from-scratch transformer.

7.1 Init loss

A randomly-initialized transformer should output approximately uniform logits. Cross-entropy of uniform over V classes is -log(1/V) = log(V). For V=1000, that's 6.91.

If init loss is way off:

  • Way higher → bad init scale; logits not centered around 0; softmax saturating.
  • Way lower → you accidentally have a constant-output bias somewhere.
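
The expected value is easy to reproduce: feed perfectly uniform (all-zero) logits to cross-entropy and compare with log(V). In this sketch `expected_init_loss` and `init_loss_ok` are hypothetical helpers, and the 0.05 tolerance is an assumption, not necessarily the threshold solution.py uses.

```python
import math
import torch
import torch.nn.functional as F

def expected_init_loss(vocab_size):
    """Cross-entropy of a uniform distribution over vocab_size classes."""
    return math.log(vocab_size)

def init_loss_ok(loss, vocab_size, tol=0.05):
    """Assumed tolerance; tune to taste."""
    return abs(loss - expected_init_loss(vocab_size)) < tol

torch.manual_seed(0)
V = 1000
logits = torch.zeros(4, 8, V)                  # uniform over the vocab
targets = torch.randint(0, V, (4, 8))
loss = F.cross_entropy(logits.view(-1, V), targets.view(-1)).item()

print(f"uniform-logit loss = {loss:.4f}, log(V) = {expected_init_loss(V):.4f}")
print(f"GPT-2 vocab: log(50257) = {expected_init_loss(50257):.4f}")
print(init_loss_ok(6.9085, V))
```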

7.2 Single-batch overfit

A correctly-wired transformer must memorize a single batch (loss → 0). If it can't:

  • Bug in the causal mask (try removing it — does it then overfit? If yes, your mask is upside-down).
  • Bug in residual connections (forgetting x = x + ...).
  • Bug in positional embeddings (model can't tell positions apart).
  • LR way too high (loss explodes) or too low (no progress).

Hitting final_loss < 0.5 in 200 steps confirms that forward, backward, and optimizer are all wired correctly.
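
The shape of such a test harness is sketched below. Two assumptions to note: `overfit_single_batch` is a hypothetical helper that expects `model(x)` to return logits only (the lab's model returns a `(logits, loss)` pair, so it would need a small adapter), and the model used here is a tiny embedding-plus-linear stand-in, not MiniGPT. The learning rate is also just a value that works for this toy.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def overfit_single_batch(model, x, y, steps=200, lr=1e-2):
    """Train on one fixed batch; return (initial_loss, final_loss)."""
    opt = torch.optim.AdamW(model.parameters(), lr=lr)
    first_loss = None
    for _ in range(steps):
        logits = model(x)                                    # (B, T, V)
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1))
        if first_loss is None:
            first_loss = loss.item()
        opt.zero_grad(set_to_none=True)
        loss.backward()
        opt.step()
    return first_loss, loss.item()

torch.manual_seed(0)
V, T = 16, 8
# Stand-in model (NOT the lab's transformer): embedding followed by a linear head.
toy_model = nn.Sequential(nn.Embedding(V, 32), nn.Linear(32, V))

x = torch.arange(T).unsqueeze(0)      # one sequence: tokens 0..7, shape (1, T)
y = (x + 1) % V                       # next-token targets

first_loss, final_loss = overfit_single_batch(toy_model, x, y)
print(f"first={first_loss:.4f}  final={final_loss:.4f}  ok={final_loss < 0.5}")
```

Because the batch never changes, the only thing being tested is whether gradients reach every parameter and the optimizer applies them.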


8. Expected output

params = 526,464
[init-loss]  got=6.9085  expected≈6.9078  ok=True
[overfit]   step  200  loss=0.0264  ok=True

If init-loss matches log(vocab_size) to two decimals and single-batch overfit drives loss < 0.5, your transformer is wired correctly.


9. Common pitfalls

  1. Forgetting / math.sqrt(d_head): softmax saturates → gradients vanish.
  2. Mask shape mismatch when T < block_size → must slice with [:, :, :T, :T].
  3. Forgetting contiguous() before view after transpose → runtime error.
  4. Missing residuals: x = self.attn(self.ln1(x)) (forgot the x +). The model trains but quality is terrible. Sanity tests catch this.
  5. Wrong mask direction: triu instead of tril → tokens attend to themselves and the future, so the model can read its own targets. Training loss still goes down (often suspiciously fast) but generation produces garbage.
  6. Tied weights only on init: assign self.head.weight = self.tok_emb.weight so both names share one tensor. Copying values creates two tensors that drift apart during training.
  7. F.cross_entropy expects raw logits and applies log-softmax internally. Don't pass softmax probabilities (double softmax).
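
Pitfall 1 is easy to see empirically. The sketch below draws unit-variance random queries and keys (an idealized assumption; trained activations will differ) and compares raw and scaled dot products, then measures how peaked the resulting softmax is.

```python
import math
import torch

torch.manual_seed(0)
d_head = 64
q = torch.randn(4096, d_head)
k = torch.randn(4096, d_head)

# Variance of q . k grows linearly with d_head; dividing by sqrt(d_head) undoes it.
raw = (q * k).sum(dim=-1)
scaled = raw / math.sqrt(d_head)
var_raw = raw.var().item()
var_scaled = scaled.var().item()

# Softmax peakedness: mean of the largest attention weight per query row.
scores = q[:256] @ k[:32].T                                  # (256, 32)
peak_raw = scores.softmax(dim=-1).max(dim=-1).values.mean().item()
peak_scaled = (scores / math.sqrt(d_head)).softmax(dim=-1).max(dim=-1).values.mean().item()

print(f"variance: raw={var_raw:.1f}  scaled={var_scaled:.2f}")
print(f"mean max attention weight: raw={peak_raw:.2f}  scaled={peak_scaled:.2f}")
```

Without scaling, most rows put nearly all their weight on a single key, which is the saturated regime where softmax gradients are close to zero.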

10. Stretch exercises

  • Implement RoPE (rotary positional embeddings). Apply rotation to Q, K inside attention. Drop the pos_emb table.
  • Implement RMSNorm. Replace LayerNorm. ~10 lines, ~10% faster.
  • Implement SwiGLU MLP.
  • Implement GQA (grouped-query attention). Set n_kv_head < n_head; broadcast K, V across query heads. Shrinks the KV cache by a factor of n_head / n_kv_head (halves it when n_kv_head = n_head / 2).
  • Use torch.nn.functional.scaled_dot_product_attention to dispatch FlashAttention. Compare wall-clock — should be 2-3× faster at long contexts.
  • Profile with torch.profiler: where is time spent? (Rough expectation: ~60% matmuls, ~20% softmax, and the rest in everything else.)
  • Reproduce the GPT-2 124M architecture exactly: 12 layers, 12 heads, d=768.

11. Connecting to later phases

| Phase | What it adds to this code |
|---|---|
| 5 (training) | Real data loader, mixed precision, gradient accumulation, cosine LR. |
| 6 (fine-tuning) | LoRA adapters wrap Linear layers; QLoRA quantizes the base. Same forward, frozen base. |
| 9 (inference) | Adds a LayerCache to CausalSelfAttention, splits forward into prefill vs decode paths. |
| 10 (distributed) | Wraps MiniGPT in FSDP for sharding across GPUs. |

You'll come back to this file 5+ times across the curriculum. Internalize it.


12. What this lab proves about you

You can implement causal multi-head attention from raw matmuls, articulate every design decision, verify correctness via init-loss + overfit, and modify it for new architectures (RoPE, SwiGLU, GQA) without breaking it. That is the bar for a Phase-4 milestone, and it is the single most-asked area in LLM interviews.