Lab 04 — Mini Transformer (Solution Walkthrough)
Phase: 4 — Attention & Transformers | Difficulty: ⭐⭐⭐⭐☆ | Time: 4–6 hours
Concept primer:
../HITCHHIKERS-GUIDE.md §Attention and §Transformer architecture. This is the most important lab in the curriculum: every later phase reuses or extends this code.
Run
pip install -r requirements.txt
python solution.py # runs init-loss + single-batch overfit sanity checks
0. The mission
Build a decoder-only Transformer from scratch in ~200 lines that:
- Implements scaled dot-product attention with causal masking.
- Uses pre-norm with residual connections.
- Passes the two universal sanity tests: init-loss matches the entropy of the uniform vocab distribution, and the model can overfit a single batch to ~zero loss in 200 steps.
This is the kernel that Phase 5 trains on TinyStories, Phase 6 fine-tunes via LoRA, and Phase 9 retrofits with a KV cache. Get this right and the rest of the curriculum compiles.
1. The math
For each token position $t$:
$$ \mathrm{Attn}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_\text{head}}} + M\right) V $$
where $M$ is the causal mask: $M_{ij} = 0$ if $i \ge j$ else $-\infty$. Multi-head attention runs n_head of these in parallel on slices of dim d_head = d_model / n_head, then concatenates.
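The formula can be checked by hand on a tiny input. The following is a dependency-free sketch of single-head causal attention, not the lab's PyTorch code: the helper names `softmax` and `causal_attention` and the toy Q, K, V values are made up for illustration.

```python
import math

def softmax(row):
    """Numerically stable softmax; math.exp(-inf) evaluates to exactly 0.0."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def causal_attention(Q, K, V):
    """Single-head attention over T tokens; Q, K, V are T x d lists of lists."""
    T, d = len(Q), len(Q[0])
    weights = []
    for i in range(T):
        scores = []
        for j in range(T):
            s = sum(Q[i][c] * K[j][c] for c in range(d)) / math.sqrt(d)
            # M_ij = 0 where i >= j, -inf where key j lies in the future
            scores.append(s if i >= j else float("-inf"))
        weights.append(softmax(scores))
    out = []
    for i in range(T):
        # Weighted sum of value rows for query position i
        out.append([sum(weights[i][j] * V[j][c] for j in range(T))
                    for c in range(len(V[0]))])
    return out, weights

# Toy inputs (hypothetical values, T = 3, d = 2)
Q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
V = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
out, weights = causal_attention(Q, K, V)
```

Position 0 can only see itself, so `weights[0]` is `[1.0, 0.0, 0.0]` and `out[0]` equals the first value row. Every later row spreads its weight over the positions at or before it.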
The full block (pre-norm):
$$ \begin{aligned} x &\leftarrow x + \mathrm{Attn}(\mathrm{LN}(x)) \\ x &\leftarrow x + \mathrm{MLP}(\mathrm{LN}(x)) \end{aligned} $$
A model is n_layer blocks stacked plus token + position embeddings at the input and a linear head at the output.
2. GPTConfig
from dataclasses import dataclass

@dataclass
class GPTConfig:
    vocab_size: int = 50257
    n_layer: int = 6
    n_head: int = 8
    d_model: int = 512
    d_ff: int = 2048         # typically 4 * d_model
    block_size: int = 1024   # max context length
    dropout: float = 0.0
    tie_weights: bool = True
- vocab_size = 50257 matches the GPT-2 BPE vocabulary.
- d_ff = 4 * d_model is the ratio used in "Attention Is All You Need" and in most models since. It gives the MLP enough capacity to act as the model's "memory" (Geva et al. 2021 showed MLP weights store factual knowledge).
- block_size is the maximum sequence length the position-embedding table supports.
- tie_weights=True shares the vocab × d_model matrix between the input embedding and the output head. This saves ~50 MB on a small model and ~1 GB on 7B, with quality identical or slightly better.
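To see where the parameters go and what weight tying saves, here is a rough counting sketch. It assumes the layout described in this walkthrough (bias-free Linear layers, LayerNorm with weight and bias, learned position embeddings); `count_params` is a hypothetical helper, not part of solution.py.

```python
def count_params(vocab_size, n_layer, d_model, d_ff, block_size, tie_weights=True):
    """Parameter count under the assumed layout (bias-free Linears, LN weight + bias)."""
    tok_emb = vocab_size * d_model
    pos_emb = block_size * d_model
    attn = d_model * 3 * d_model + d_model * d_model   # fused qkv + output proj
    mlp = 2 * d_model * d_ff                           # fc + proj
    norms = 2 * (2 * d_model)                          # ln1 + ln2, each weight + bias
    per_block = attn + mlp + norms
    ln_f = 2 * d_model
    head = 0 if tie_weights else vocab_size * d_model  # a tied head adds nothing
    return tok_emb + pos_emb + n_layer * per_block + ln_f + head

# Default GPTConfig values from above
tied = count_params(50257, 6, 512, 2048, 1024, tie_weights=True)
untied = count_params(50257, 6, 512, 2048, 1024, tie_weights=False)
saved = untied - tied
print(f"tied={tied:,} untied={untied:,} saved={saved:,}")
print(f"saved at 2 bytes/param: {saved * 2 / 1e6:.1f} MB")
```

The saved matrix is 50257 × 512 = 25,731,584 parameters, which is about 51 MB at 2 bytes per parameter and consistent with the ~50 MB figure above.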
3. CausalSelfAttention — the centerpiece
3.1 The fused QKV projection
self.qkv = nn.Linear(cfg.d_model, 3 * cfg.d_model, bias=False)
One large matmul is faster than three smaller ones (better GPU utilization). Mathematically identical to three separate linears. bias=False is the modern default — biases add parameters without measurable quality benefit at scale.
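The "mathematically identical" claim is easy to confirm on a toy example. This sketch uses plain lists instead of tensors; `matvec` and the weight values are illustrative assumptions, with weights stored as rows of outputs in the same layout nn.Linear uses.

```python
def matvec(W, x):
    """y = W x for a weight stored as a list of output rows (nn.Linear layout)."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

d_model = 2
W_q = [[1.0, 2.0], [3.0, 4.0]]
W_k = [[0.5, -1.0], [2.0, 0.0]]
W_v = [[-1.0, 1.0], [0.0, 3.0]]
x = [1.0, -2.0]

# Fused weight: stack the three (d_model x d_model) matrices along the output dim.
W_qkv = W_q + W_k + W_v                      # shape (3 * d_model, d_model)
qkv = matvec(W_qkv, x)                       # one matmul instead of three

# Equivalent of qkv.split(C, dim=-1): three consecutive chunks of size d_model.
q, k, v = (qkv[i * d_model:(i + 1) * d_model] for i in range(3))
```

Each chunk matches the output of the corresponding separate projection exactly, because the fused matmul performs the same row-by-row dot products.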
3.2 The causal mask buffer
self.register_buffer(
    "mask",
    torch.tril(torch.ones(cfg.block_size, cfg.block_size, dtype=torch.bool))
    .view(1, 1, cfg.block_size, cfg.block_size),
    persistent=False,
)
- torch.tril(...) gives a lower-triangular boolean matrix: True on and below the diagonal. Position i can attend to position j iff i ≥ j.
- Shape (1, 1, T, T) so it broadcasts over the batch and head dims.
- register_buffer so the mask moves to the GPU with .to(device).
- persistent=False keeps it out of state_dict (it is deterministically reconstructable).
3.3 The forward — six lines that contain the whole transformer
def forward(self, x):
    B, T, C = x.shape
    qkv = self.qkv(x)  # (B, T, 3C)
    q, k, v = qkv.split(C, dim=-1)
    q = q.view(B, T, self.n_head, self.d_head).transpose(1, 2)
    k = k.view(B, T, self.n_head, self.d_head).transpose(1, 2)
    v = v.view(B, T, self.n_head, self.d_head).transpose(1, 2)
    att = (q @ k.transpose(-2, -1)) / math.sqrt(self.d_head)
    att = att.masked_fill(~self.mask[:, :, :T, :T], float("-inf"))
    att = F.softmax(att, dim=-1)
    att = self.attn_drop(att)
    y = att @ v
    y = y.transpose(1, 2).contiguous().view(B, T, C)
    return self.resid_drop(self.proj(y))
Decoded:
- qkv.split(C, dim=-1): splits the fused projection into Q, K, V, each of shape (B, T, C).
- view + transpose(1, 2): reshapes to (B, n_head, T, d_head). Moving the head dim next to the batch dim leaves (T, d_head) as the last two dims, so the batched matmul treats every head as an independent matrix.
- q @ k.transpose(-2, -1): batched matmul → attention scores of shape (B, n_head, T, T).
- / math.sqrt(self.d_head): the most important divisor in deep learning. For unit-variance q and k entries, unscaled scores have variance d_head, which pushes softmax into saturation and makes gradients vanish.
- masked_fill(~mask, -inf): -inf rather than a finite sentinel such as -1e9, because exp(-inf) is exactly 0 in every dtype, whereas -1e9 is out of range for fp16. The causal mask always leaves the diagonal unmasked, so no row is entirely -inf.
- softmax(dim=-1): normalizes across the key dimension. Each row sums to 1.
- att @ v → (B, n_head, T, d_head): weighted sum of values.
- transpose(1, 2).contiguous().view(B, T, C): undoes the head split. contiguous() is required before view because transpose only changes strides.
- self.proj(y): output projection, which mixes information across heads.
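The variance claim behind the sqrt(d_head) divisor can be checked empirically. This is a small simulation sketch, assuming q and k entries are independent with unit variance (an idealization of the state at init); the sample count and the 8-score row are arbitrary choices.

```python
import math
import random

def softmax(row):
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

random.seed(0)
d_head = 64
n_samples = 5000

# Dot products of vectors with independent unit-variance entries.
dots = []
for _ in range(n_samples):
    q = [random.gauss(0.0, 1.0) for _ in range(d_head)]
    k = [random.gauss(0.0, 1.0) for _ in range(d_head)]
    dots.append(sum(a * b for a, b in zip(q, k)))

mean = sum(dots) / n_samples
var_raw = sum((s - mean) ** 2 for s in dots) / n_samples
var_scaled = var_raw / d_head   # dividing scores by sqrt(d_head) divides variance by d_head

# How peaked is softmax over one row of 8 scores, with and without the scale?
row = dots[:8]
peak_raw = max(softmax(row))
peak_scaled = max(softmax([s / math.sqrt(d_head) for s in row]))
print(f"var raw={var_raw:.1f} scaled={var_scaled:.2f}")
print(f"max softmax weight raw={peak_raw:.3f} scaled={peak_scaled:.3f}")
```

The raw variance should land near d_head and the scaled variance near 1. The unscaled row puts more of its softmax mass on a single key than the scaled row does, which is the saturation the divisor prevents.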
3.4 Why two dropouts?
attn_drop zeroes random attention weights (some keys get "ignored"); resid_drop zeroes random elements of the sublayer output before it is added to the residual stream. Both are 0 in this skeleton; turn them on when fine-tuning on small datasets.
4. MLP
class MLP(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.fc = nn.Linear(cfg.d_model, cfg.d_ff, bias=False)
        self.proj = nn.Linear(cfg.d_ff, cfg.d_model, bias=False)
        self.drop = nn.Dropout(cfg.dropout)

    def forward(self, x):
        return self.drop(self.proj(F.gelu(self.fc(x))))
GELU = x * Φ(x) (smooth ReLU); empirically better than ReLU for transformers.
Modern variants use SwiGLU (Llama, Qwen): W_d (SiLU(W_g x) * (W_u x)). That is three matrices instead of two, so at equal d_ff it adds 50% MLP params, for a perplexity improvement of roughly 2%.
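Both the GELU definition and the SwiGLU parameter arithmetic are small enough to verify directly. The sketch below uses the exact erf form of GELU; `mlp_params` and `swiglu_params` are hypothetical helpers, and the 2/3 width reduction is a common convention for restoring parameter parity rather than something this lab prescribes.

```python
import math

def gelu(x):
    """Exact GELU: x * Phi(x), with Phi the standard normal CDF."""
    return x * 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))

def mlp_params(d_model, d_ff):
    return 2 * d_model * d_ff            # fc + proj, no biases

def swiglu_params(d_model, d_ff):
    return 3 * d_model * d_ff            # W_g, W_u, W_d, no biases

d_model, d_ff = 512, 2048
ratio_same_width = swiglu_params(d_model, d_ff) / mlp_params(d_model, d_ff)

# Shrinking the hidden width to about 2/3 brings the count back to parity.
d_ff_parity = 2 * d_ff // 3
ratio_parity = swiglu_params(d_model, d_ff_parity) / mlp_params(d_model, d_ff)
```

At equal width the ratio is exactly 1.5; at two-thirds width it returns to roughly 1, so the comparison between the two MLP types is then at equal parameter budget.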
5. Block — the pre-norm layout
class Block(nn.Module):
    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x
Pre-norm vs post-norm is one of the most consequential architecture choices:
- Post-norm (original 2017 paper): x = LN(x + sublayer(x)). Trains poorly without warmup; gradients pass through LN on every residual.
- Pre-norm (GPT-2 onwards): x = x + sublayer(LN(x)). The residual stream is "clean": gradients flow unimpeded through every layer, so deep stacks train far more stably and depend much less on warmup.
Modern alternative: RMSNorm (Llama) — drops mean-subtraction; ~10% faster, identical quality.
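The two layouts and the two normalizers differ by only a line each. The following is a list-based sketch with gain and bias omitted for brevity; all four function names are illustrative and not taken from solution.py.

```python
import math

def layer_norm(x, eps=1e-5):
    """Subtract the mean, divide by the standard deviation (gain and bias omitted)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]

def rms_norm(x, eps=1e-5):
    """Divide by the root-mean-square only; no mean subtraction (gain omitted)."""
    ms = sum(v * v for v in x) / len(x)
    return [v / math.sqrt(ms + eps) for v in x]

def pre_norm_step(x, sublayer, norm):
    """x <- x + sublayer(norm(x)): the residual path itself is never normalized."""
    return [a + b for a, b in zip(x, sublayer(norm(x)))]

def post_norm_step(x, sublayer, norm):
    """x <- norm(x + sublayer(x)): every residual sum passes through the norm."""
    return norm([a + b for a, b in zip(x, sublayer(x))])

x = [1.0, 2.0, 3.0, 6.0]
zero_sublayer = lambda h: [0.0] * len(h)   # a sublayer that contributes nothing

pre = pre_norm_step(x, zero_sublayer, layer_norm)
post = post_norm_step(x, zero_sublayer, layer_norm)
```

With a sublayer that outputs zeros, the pre-norm step returns `x` untouched while the post-norm step still rescales it. That identity path is what "clean residual stream" refers to.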
6. MiniGPT
self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.d_model)
self.pos_emb = nn.Embedding(cfg.block_size, cfg.d_model)
self.blocks = nn.ModuleList([Block(cfg) for _ in range(cfg.n_layer)])
self.ln_f = nn.LayerNorm(cfg.d_model)
self.head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
if cfg.tie_weights:
    self.head.weight = self.tok_emb.weight
- Learned absolute position embeddings (GPT-2 style). Modern models use RoPE (rotary, applied in attention itself), which handles longer contexts and extrapolates better.
- Final LayerNorm before the head (ln_f): important for training stability.
- Weight tying by direct assignment. Both tok_emb.weight and head.weight point to the same tensor → only one tensor in the optimizer.
6.1 Init
nn.init.normal_(m.weight, mean=0.0, std=0.02)
std=0.02 is GPT-2's choice. Theoretically 0.02 / sqrt(2 * n_layer) is better for residual-path projections (keeps activation variance constant across layers), but 0.02 everywhere works fine for small models.
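The scaled variant is simple arithmetic. As a sketch, assuming each block contributes two independent residual additions (attention projection and MLP projection), the helper below (a hypothetical name) computes the reduced std and checks that the summed variance over all 2 * n_layer additions comes back to the base value.

```python
import math

def residual_proj_std(n_layer, base_std=0.02):
    """Scaled init std for the two residual-path projections in each block."""
    return base_std / math.sqrt(2 * n_layer)

std_6 = residual_proj_std(6)      # depth of the default GPTConfig
std_12 = residual_proj_std(12)    # depth of GPT-2 124M

# 2 * n_layer contributions, each with variance std**2, sum to base_std**2.
total_var_6 = 2 * 6 * std_6 ** 2
print(f"n_layer=6: std={std_6:.5f}  n_layer=12: std={std_12:.5f}")
```

For six layers this gives a std of about 0.0058, roughly 3.5 times smaller than the flat 0.02.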
6.2 generate
for _ in range(max_new_tokens):
    ctx = idx[:, -self.cfg.block_size:]
    logits, _ = self(ctx)
    logits = logits[:, -1, :] / max(1e-6, temperature)
    if top_k is not None:
        v, _ = torch.topk(logits, top_k)
        logits[logits < v[:, [-1]]] = float("-inf")
    probs = F.softmax(logits, dim=-1)
    next_id = torch.multinomial(probs, 1)
    idx = torch.cat([idx, next_id], dim=1)
- ctx = idx[:, -block_size:] truncates the context to the model's maximum: naive but correct. The KV-cache lab in Phase 9 makes this efficient.
- Attention costs O(T²) per generated token because the entire context is re-processed. Phase 9 fixes this with a KV cache → O(T) per token.
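The temperature and top-k steps of the loop can be isolated from the model. This is a stdlib sketch of the same filtering on a single list of logits; `sample_next` is a hypothetical helper, and unlike torch.topk it clamps top_k to the vocabulary size instead of raising.

```python
import math
import random

def sample_next(logits, temperature=1.0, top_k=None, rng=random):
    """Sample one token id from a list of logits; returns (token_id, probs)."""
    scaled = [l / max(1e-6, temperature) for l in logits]
    if top_k is not None:
        # Value of the k-th largest logit; anything strictly below it is masked.
        kth = sorted(scaled, reverse=True)[min(top_k, len(scaled)) - 1]
        scaled = [s if s >= kth else float("-inf") for s in scaled]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]   # exp(-inf) is exactly 0.0
    total = sum(exps)
    probs = [e / total for e in exps]
    token_id = rng.choices(range(len(probs)), weights=probs, k=1)[0]
    return token_id, probs

random.seed(0)
logits = [2.0, 1.0, 0.5, -1.0, -3.0]   # toy values
token_id, probs = sample_next(logits, temperature=0.8, top_k=2)
```

With top_k=2 only the two largest logits keep non-zero probability, so the sampled id is always 0 or 1. As in the loop above, ties at the threshold are kept, so slightly more than k tokens can survive.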
7. The two sanity tests
These should be the first things you run on any from-scratch transformer.
7.1 Init loss
A randomly-initialized transformer should output approximately uniform logits. Cross-entropy of uniform over V classes is -log(1/V) = log(V). For V=1000, that's 6.91.
If init loss is way off:
- Way higher → bad init scale; logits not centered around 0; softmax saturating.
- Way lower → you accidentally have a constant-output bias somewhere.
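The expected value is worth computing rather than memorizing. A minimal sketch, with a hand-written `cross_entropy` standing in for F.cross_entropy on a single example; the target index is arbitrary.

```python
import math

def cross_entropy(logits, target):
    """-log softmax(logits)[target], computed with a stable log-sum-exp."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(l - m) for l in logits))
    return lse - logits[target]

V = 1000
uniform_logits = [0.0] * V          # idealized output of a freshly initialized model
loss = cross_entropy(uniform_logits, target=123)
expected = math.log(V)
print(f"loss={loss:.4f} expected={expected:.4f}")

# The same check for the default GPTConfig vocabulary
expected_gpt2_vocab = math.log(50257)
```

For V=1000 both values are 6.9078. With the default vocab_size of 50257 the target becomes roughly 10.82, so compare against log(vocab_size) for whatever config the test actually uses.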
7.2 Single-batch overfit
A correctly-wired transformer must memorize a single batch (loss → 0). If it can't:
- Bug in the causal mask (try removing it — does it then overfit? If yes, your mask is upside-down).
- Bug in residual connections (forgetting x = x + ...).
- Bug in positional embeddings (model can't tell positions apart).
- LR way too high (loss explodes) or too low (no progress).
Hitting final_loss < 0.5 in 200 steps confirms forward + backward + optimizer all wire correctly.
8. Expected output
params = 526,464
[init-loss] got=6.9085 expected≈6.9078 ok=True
[overfit] step 200 loss=0.0264 ok=True
If init-loss matches log(vocab_size) to two decimals and single-batch overfit drives loss < 0.5, your transformer is wired correctly.
9. Common pitfalls
- Forgetting / math.sqrt(d_head): softmax saturates → gradients vanish.
- Mask shape mismatch when T < block_size → must slice with [:, :, :T, :T].
- Forgetting contiguous() before view after transpose → runtime error.
- Missing residuals: x = self.attn(self.ln1(x)) (forgot the x +). The model trains but quality is terrible. Sanity tests catch this.
- Wrong mask direction: triu instead of tril → tokens attend to themselves and the future instead of the past. Loss can still go down, because each position can see its own target, but generation produces garbage.
- Tied weights only on init: must assign self.head.weight = self.tok_emb.weight, not copy values.
- F.cross_entropy expects raw logits, not probabilities. Don't double-softmax.
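The double-softmax pitfall is silent because the loss still looks plausible. This sketch mimics the contract of F.cross_entropy with a hand-written function (the helper names and toy logits are illustrative) and compares three ways of calling it.

```python
import math

def log_softmax(logits):
    m = max(logits)
    lse = m + math.log(sum(math.exp(l - m) for l in logits))
    return [l - lse for l in logits]

def softmax(logits):
    return [math.exp(v) for v in log_softmax(logits)]

def cross_entropy(logits, target):
    """Same contract as F.cross_entropy: the input is treated as raw logits."""
    return -log_softmax(logits)[target]

logits = [4.0, 1.0, -2.0]   # a confident, correct prediction for class 0
target = 0

loss_correct = cross_entropy(logits, target)
loss_double = cross_entropy(softmax(logits), target)         # probabilities passed as logits
loss_logprobs = cross_entropy(log_softmax(logits), target)   # log-probs passed as logits
print(f"correct={loss_correct:.4f} double={loss_double:.4f} logprobs={loss_logprobs:.4f}")
```

Passing probabilities squashes the inputs into [0, 1], so the loss can never drop below log(1 + (V - 1) / e) no matter how confident the model is, and the overfit test stalls. Passing log-probabilities happens to give the same value as raw logits, since log-softmax applied twice changes nothing, but it is still redundant work.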
10. Stretch exercises
- Implement RoPE (rotary positional embeddings). Apply the rotation to Q and K inside attention. Drop the pos_emb table.
- Implement RMSNorm. Replace LayerNorm. ~10 lines, ~10% faster.
- Implement SwiGLU MLP.
- Implement GQA (grouped-query attention). Set n_kv_head < n_head; broadcast K, V across query heads. The KV cache shrinks by a factor of n_head / n_kv_head (halved when n_kv_head = n_head / 2).
- Use torch.nn.functional.scaled_dot_product_attention to dispatch FlashAttention. Compare wall-clock: it should be 2-3× faster at long contexts.
- Profile with torch.profiler: where is time spent? (Roughly 60% matmuls, 20% softmax, and the remainder everything else.)
- Reproduce the GPT-2 124M architecture exactly: 12 layers, 12 heads, d=768.
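As a starting point for the RoPE exercise, the key property to test is that attention scores depend only on the relative offset between query and key positions. The sketch below rotates one vector at a time using the interleaved-pair convention with base 10000; `rope` and `dot` are hypothetical helpers, and some implementations pair the first half of the vector with the second half instead.

```python
import math

def rope(x, pos, base=10000.0):
    """Rotate consecutive pairs (x[0], x[1]), (x[2], x[3]), ... by angle pos * theta."""
    d = len(x)
    out = []
    for i in range(0, d, 2):
        theta = base ** (-i / d)          # frequency for the pair starting at index i
        c, s = math.cos(pos * theta), math.sin(pos * theta)
        out += [x[i] * c - x[i + 1] * s, x[i] * s + x[i + 1] * c]
    return out

def dot(a, b):
    return sum(u * v for u, v in zip(a, b))

q = [0.3, -1.2, 0.7, 2.0]   # toy query and key, d_head = 4
k = [1.5, 0.4, -0.6, 0.9]

score_a = dot(rope(q, 5), rope(k, 2))        # positions 5 and 2, offset 3
score_b = dot(rope(q, 105), rope(k, 102))    # same offset, shifted by 100
score_c = dot(rope(q, 5), rope(k, 4))        # different offset
```

`score_a` and `score_b` agree up to floating-point error, while `score_c` differs. A unit test along these lines is a useful third sanity check once the pos_emb table is gone.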
11. Connecting to later phases
| Phase | What it adds to this code |
|---|---|
| 5 (training) | Real data loader, mixed precision, gradient accumulation, cosine LR. |
| 6 (fine-tuning) | LoRA adapters wrap Linear layers; QLoRA quantizes the base. Same forward, frozen base. |
| 9 (inference) | Adds a LayerCache to CausalSelfAttention, splits forward into prefill vs decode paths. |
| 10 (distributed) | Wraps MiniGPT in FSDP for sharding across GPUs. |
You'll come back to this file 5+ times across the curriculum. Internalize it.
12. What this lab proves about you
You can implement causal multi-head attention from raw matmuls, articulate every design decision, verify correctness via init-loss + overfit, and modify it for new architectures (RoPE, SwiGLU, GQA) without breaking it. That is the bar for a Phase-4 milestone, and the single most-asked area of LLM interviews.