Lab 01 — KV-Cache From Scratch (Solution Walkthrough)
Phase: 9 — Inference & Serving | Difficulty: ⭐⭐⭐⭐☆ | Time: 3–5 hours
Concept primer: `../HITCHHIKERS-GUIDE.md` §KV cache, §Prefill vs decode, §PagedAttention.
Run:

```bash
pip install -r requirements.txt
python solution.py
```
0. The mission
Retrofit the Phase-4 transformer with a KV cache and measure the speedup. This is the single most important inference optimization — every production engine (vLLM, TGI, TensorRT-LLM, llama.cpp) is structured around managing this cache.
The two questions you must answer at the end:
- Why is decoding without a cache $O(T^2)$ per generated token?
- Why does the cache reduce it to $O(T)$ and enable continuous batching?
1. The math
For a sequence of length $T$, attention costs:
$$ \text{Attention FLOPs} \approx 4 T^2 d $$
(quadratic in $T$). When generating token $T+1$, without a cache you re-process tokens $1..T$ from scratch → each generated token is $O(T^2)$. Total cost to generate $N$ tokens from prompt of length $P$:
$$ \sum_{t=P}^{P+N} O(t^2) = O\!\left((P+N)^3\right) $$
With a KV cache, when generating token $T+1$:
- Compute $q$, $k$, $v$ only for the new token (1 token) and append its $k$, $v$ to the cache.
- Look up cached K, V for tokens $1..T$.
- Compute attention as $q \cdot K^\top$ which is $O(T \cdot d)$.
Generating $N$ tokens after prefilling $P$:
$$ O(P^2 d) \text{ for prefill} + \sum_{t=P}^{P+N} O(t \cdot d) = O\!\left((P+N)^2 d\right) $$
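For $P=128, N=128$: cube vs square means the two bounds differ by a factor of $P+N = 256$; counting the attention terms exactly (prefill included) gives roughly 120× fewer attention FLOPs.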
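The bound ratio drops constants and exact summation ranges, so it is only a rough guide. A minimal sketch that counts attention FLOPs with the $4 \cdot n_q \cdot n_k \cdot d$ rule from above (per layer, ignoring the QKV and MLP projections; `attention_flops` is a hypothetical helper) makes the comparison concrete:

```python
def attention_flops(P, N, d=1):
    """Attention-only FLOPs (4 * n_queries * n_keys * d per layer); projections ignored."""
    # No cache: generating N tokens runs full attention over lengths P, P+1, ..., P+N-1.
    no_cache = sum(4 * t * t * d for t in range(P, P + N))
    # Cache: one P x P prefill (yields token 1), then N-1 single-query steps over t keys.
    cached = 4 * P * P * d + sum(4 * 1 * t * d for t in range(P + 1, P + N))
    return no_cache, cached

naive, cached = attention_flops(128, 128)
print(f"no-cache / cached attention FLOPs: {naive / cached:.1f}x")
```

For $P = N = 128$ the ratio comes out at about 119: below the 256 suggested by the bounds, but still two orders of magnitude.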
2. The two phases of inference
The single most important conceptual split in serving:
| Phase | Input | Compute character | Bottleneck |
|---|---|---|---|
| Prefill | All P prompt tokens at once | Compute-bound (big matmul) | TFLOPS |
| Decode | One token per step, once per generated token | Memory-bound (tiny matmul, big weight load) | Memory bandwidth |
Metrics map directly:
- TTFT (time to first token) = prefill latency.
- ITL (inter-token latency) = decode latency.
Batching helps decode hugely (each batch element shares the weight load) but barely helps prefill (already compute-saturated). This is why continuous batching dynamically merges incoming requests — they spend most of their time in decode anyway.
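To see why sharing the weight load matters, here is a rough arithmetic-intensity estimate (FLOPs per byte of memory traffic) for one $d \times d$ FP16 projection applied to $n$ tokens. `matmul_intensity` is a hypothetical helper and $d = 4096$ is an assumed, Llama-7B-like width; for comparison, the ridge point of current GPUs (peak FLOP/s divided by memory bandwidth) is on the order of a hundred to a few hundred FLOPs per byte.

```python
def matmul_intensity(n_tokens, d, bytes_per_elem=2):
    """FLOPs per byte for y = x @ W, x: (n_tokens, d), W: (d, d), FP16 by default."""
    flops = 2 * n_tokens * d * d                               # one multiply-add per weight per token
    bytes_moved = bytes_per_elem * (d * d + 2 * n_tokens * d)  # read W and x, write y
    return flops / bytes_moved

for n in (1, 8, 64, 512):
    print(f"n_tokens={n:4d}  intensity={matmul_intensity(n, d=4096):7.1f} FLOP/byte")
```

A single decode token does about 1 FLOP per byte, so the step is dominated by streaming the weights; batching $B$ decode sequences raises the intensity to roughly $B$, while a 512-token prefill already reaches about 410.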
3. LayerCache — the data structure
```python
from dataclasses import dataclass

import torch

@dataclass
class LayerCache:
    k: torch.Tensor | None = None  # (B, n_head, T_cur, d_head)
    v: torch.Tensor | None = None  # same shape as k

    def append(self, new_k, new_v):
        """Append this step's K, V along the time dim and return the full cache."""
        if self.k is None:
            self.k = new_k
            self.v = new_v
        else:
            self.k = torch.cat([self.k, new_k], dim=2)
            self.v = torch.cat([self.v, new_v], dim=2)
        return self.k, self.v
```
Design decisions:
- Per-layer cache: each transformer layer has its own K, V tensors. Total cache size = `n_layer × 2 × B × n_head × T × d_head × dtype_bytes`. For Llama-7B at T=2048 in FP16: ~1 GB per request. This is why memory-bound serving is hard.
- Concat on `dim=2` (the time dim). Naive but correct. Production engines (vLLM) don't concat; they use paged allocation in fixed-size blocks (16 tokens) to avoid the O(T) reallocation and to enable shared prefix caching.
- Naive concat is O(T) per step: every decode step copies the entire growing cache. For long contexts this becomes a bottleneck. Pre-allocating a max-size buffer fixes this; PagedAttention generalizes the fix.
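Plugging numbers into the size formula above makes the "~1 GB per request" figure concrete. The sketch assumes Llama-7B-like shapes (32 layers, 32 heads of dimension 128) stored in FP16; `kv_cache_bytes` is a hypothetical helper.

```python
def kv_cache_bytes(n_layer, n_head, d_head, seq_len, batch=1, dtype_bytes=2):
    """n_layer x 2 (K and V) x B x n_head x T x d_head x dtype_bytes."""
    return n_layer * 2 * batch * n_head * seq_len * d_head * dtype_bytes

# Assumed Llama-7B-style shapes, FP16 cache, one 2048-token sequence.
size = kv_cache_bytes(n_layer=32, n_head=32, d_head=128, seq_len=2048)
print(size / 2**30, "GiB per request")
print(size // 2048 // 1024, "KiB per token")
```

That is exactly 1 GiB per sequence, or 512 KiB for every token of context.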
4. CachedSelfAttention — the modified forward
```python
import math

import torch
import torch.nn.functional as F

# Method of CachedSelfAttention (an nn.Module with qkv, proj, n_head, d_head)
def forward(self, x, cache: LayerCache | None = None):
    B, T, C = x.shape
    qkv = self.qkv(x)
    q, k, v = qkv.split(C, dim=-1)
    # (B, T, C) -> (B, n_head, T, d_head)
    q = q.view(B, T, self.n_head, self.d_head).transpose(1, 2)
    k = k.view(B, T, self.n_head, self.d_head).transpose(1, 2)
    v = v.view(B, T, self.n_head, self.d_head).transpose(1, 2)
    if cache is not None:
        k, v = cache.append(k, v)  # 👈 prepend cached K, V
    T_total = k.size(2)
    att = (q @ k.transpose(-2, -1)) / math.sqrt(self.d_head)  # (B, n_head, T, T_total)
    # Causal mask: new query i sits at absolute position T_total - T + i and attends to
    # k[..., : T_total - T + i + 1], i.e. a lower triangle shifted right by T_total - T.
    if cache is None or T > 1:  # prefill or no cache
        mask = torch.tril(
            torch.ones(T, T_total, dtype=torch.bool, device=x.device),
            diagonal=T_total - T,
        )
        att = att.masked_fill(~mask, float("-inf"))
    # decode (T == 1) needs no mask: q can attend to all of k by definition
    att = F.softmax(att, dim=-1)
    y = att @ v
    y = y.transpose(1, 2).contiguous().view(B, T, C)
    return self.proj(y)
```
Three changes from Phase 4's attention:
- K, V come from concatenation: new K, V for the just-arrived tokens; previous K, V from the cache.
- Q covers only the new tokens (length `T`), but K, V cover the full length `T_total`.
- Mask shape is `(T, T_total)`: rows are queries, columns are keys. Position `i` (in the new tokens) corresponds to absolute position `T_total - T + i` and can attend to keys `0..T_total - T + i`, so the lower triangle is shifted right by `T_total - T` (`diagonal=T_total - T` in `torch.tril`).
During decode (T == 1), the mask is trivially "attend to everything" — we skip computing it.
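A quick way to convince yourself the `T_total - T` offset is right is to compare incremental attention against a one-shot reference. The sketch below has no learned projections, and `attend` is a hypothetical helper. It uses `F.scaled_dot_product_attention` with `is_causal=True` as the reference; PyTorch aligns that causal mask to the top-left corner, so it is only valid when `T == T_total`, which is exactly why the cached path needs an explicit offset mask.

```python
import math

import torch
import torch.nn.functional as F

def attend(q, k, v):
    """Manual attention where the T queries are the LAST T positions of the T_total keys."""
    T, T_total = q.size(-2), k.size(-2)
    att = (q @ k.transpose(-2, -1)) / math.sqrt(q.size(-1))
    mask = torch.tril(torch.ones(T, T_total, dtype=torch.bool), diagonal=T_total - T)
    att = att.masked_fill(~mask, float("-inf"))
    return F.softmax(att, dim=-1) @ v

torch.manual_seed(0)
B, H, T, D = 2, 4, 10, 8
q, k, v = (torch.randn(B, H, T, D) for _ in range(3))

# Reference: the whole sequence at once (T == T_total, so top-left causal is correct).
full = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Incremental: a 6-token prefill, a 2-token chunk, then two 1-token decode steps.
outs, k_cache, v_cache = [], None, None
for start, end in [(0, 6), (6, 8), (8, 9), (9, 10)]:
    k_new, v_new = k[:, :, start:end], v[:, :, start:end]
    k_cache = k_new if k_cache is None else torch.cat([k_cache, k_new], dim=2)
    v_cache = v_new if v_cache is None else torch.cat([v_cache, v_new], dim=2)
    outs.append(attend(q[:, :, start:end], k_cache, v_cache))
incremental = torch.cat(outs, dim=2)

print(torch.allclose(full, incremental, atol=1e-5))
```

Dropping `diagonal=T_total - T` would break the 2-token chunk: its first query would see only key 0 instead of keys 0..6.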
5. The two generation paths
5.1 Reference: no-cache generation
```python
@torch.no_grad()
def generate_no_cache(model, prompt, max_new):
    out = prompt.clone()
    for _ in range(max_new):
        logits = model(out)  # 👈 reprocesses the entire sequence
        next_id = logits[:, -1, :].argmax(-1, keepdim=True)
        out = torch.cat([out, next_id], dim=1)
    return out
```
Cost: each iteration re-runs the full transformer over the entire current sequence. Sublime in its inefficiency.
5.2 With cache
```python
@torch.no_grad()
def generate_kv_cache(model, prompt, max_new):
    caches = [LayerCache() for _ in range(model.n_layer)]
    # PREFILL: process the prompt once, populate caches
    logits = model(prompt, caches=caches)
    next_id = logits[:, -1, :].argmax(-1, keepdim=True)
    out = torch.cat([prompt, next_id], dim=1)
    # DECODE: feed only the new token, reuse caches
    for _ in range(max_new - 1):
        logits = model(next_id, caches=caches)  # input is (B, 1)
        next_id = logits[:, -1, :].argmax(-1, keepdim=True)
        out = torch.cat([out, next_id], dim=1)
    return out
```
- `caches` is a list of `LayerCache`, one per transformer block, mutated in place by each forward pass.
- Prefill consumes the prompt; decode steps consume one token each.
- The model's `forward` accepts an optional `caches` list and threads each entry to the matching block's attention layer.
6. Correctness verification
The most important test:
```python
out1 = generate_no_cache(model, prompt, max_new=64)
out2 = generate_kv_cache(model, prompt, max_new=64)
assert torch.equal(out1, out2), "KV cache must produce identical tokens"
```
Because we use greedy decoding (argmax), both paths should produce identical output sequences: their logits differ only by floating-point rounding, which can flip an argmax only on a near-tie. If they differ:
- Off-by-one in cache appending (you doubled the new tokens).
- Wrong mask shape during decode (`T == 1` case).
- Position embedding bug: you forgot to advance positions during decode.
If you're using sampled (non-deterministic) generation, fix the seed and the same property holds.
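The test above assumes the full lab model. A self-contained miniature of sections 3 to 7, with a hypothetical two-layer `TinyGPT` (learned absolute positions, random weights) whose attention uses `F.scaled_dot_product_attention` plus the offset mask, lets you run the same equivalence check in seconds on a CPU. It runs in float64 so that rounding differences between the two paths cannot realistically flip an argmax.

```python
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

@dataclass
class LayerCache:
    k: Optional[torch.Tensor] = None  # (B, n_head, T_cur, d_head)
    v: Optional[torch.Tensor] = None

    def append(self, new_k, new_v):
        self.k = new_k if self.k is None else torch.cat([self.k, new_k], dim=2)
        self.v = new_v if self.v is None else torch.cat([self.v, new_v], dim=2)
        return self.k, self.v

class Block(nn.Module):
    def __init__(self, d, n_head):
        super().__init__()
        self.n_head, self.d_head = n_head, d // n_head
        self.ln1, self.ln2 = nn.LayerNorm(d), nn.LayerNorm(d)
        self.qkv, self.proj = nn.Linear(d, 3 * d), nn.Linear(d, d)
        self.mlp = nn.Sequential(nn.Linear(d, 4 * d), nn.GELU(), nn.Linear(4 * d, d))

    def attn(self, x, cache):
        B, T, C = x.shape
        q, k, v = [t.view(B, T, self.n_head, self.d_head).transpose(1, 2)
                   for t in self.qkv(x).split(C, dim=-1)]
        if cache is not None:
            k, v = cache.append(k, v)  # K, V now span past + new tokens
        T_total = k.size(2)
        mask = torch.tril(torch.ones(T, T_total, dtype=torch.bool), diagonal=T_total - T)
        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
        return self.proj(y.transpose(1, 2).reshape(B, T, C))

    def forward(self, x, cache=None):
        x = x + self.attn(self.ln1(x), cache)
        return x + self.mlp(self.ln2(x))

class TinyGPT(nn.Module):
    def __init__(self, vocab=50, d=32, n_head=4, n_layer=2, max_len=64):
        super().__init__()
        self.n_layer = n_layer
        self.tok_emb, self.pos_emb = nn.Embedding(vocab, d), nn.Embedding(max_len, d)
        self.blocks = nn.ModuleList([Block(d, n_head) for _ in range(n_layer)])
        self.head = nn.Linear(d, vocab)

    def forward(self, idx, caches=None):
        B, T = idx.shape
        past_len = caches[0].k.size(2) if caches and caches[0].k is not None else 0
        pos = torch.arange(past_len, past_len + T)  # advance positions during decode
        x = self.tok_emb(idx) + self.pos_emb(pos)
        for i, block in enumerate(self.blocks):
            x = block(x, caches[i] if caches else None)
        return self.head(x)

@torch.no_grad()
def generate_no_cache(model, prompt, max_new):
    out = prompt.clone()
    for _ in range(max_new):
        next_id = model(out)[:, -1, :].argmax(-1, keepdim=True)
        out = torch.cat([out, next_id], dim=1)
    return out

@torch.no_grad()
def generate_kv_cache(model, prompt, max_new):
    caches = [LayerCache() for _ in range(model.n_layer)]
    next_id = model(prompt, caches=caches)[:, -1, :].argmax(-1, keepdim=True)  # prefill
    out = torch.cat([prompt, next_id], dim=1)
    for _ in range(max_new - 1):  # decode: one token per step
        next_id = model(next_id, caches=caches)[:, -1, :].argmax(-1, keepdim=True)
        out = torch.cat([out, next_id], dim=1)
    return out

torch.manual_seed(0)
model = TinyGPT().double().eval()
prompt = torch.randint(0, 50, (2, 8))
out1 = generate_no_cache(model, prompt, max_new=16)
out2 = generate_kv_cache(model, prompt, max_new=16)
print(torch.equal(out1, out2))
```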
7. Position embeddings during decode
With learned absolute position embeddings:
```python
def forward(self, idx, caches=None):
    B, T = idx.shape
    past_len = caches[0].k.size(2) if caches and caches[0].k is not None else 0
    pos = torch.arange(past_len, past_len + T, device=idx.device)
    x = self.tok_emb(idx) + self.pos_emb(pos)
    ...
```
The new tokens get positions past_len, past_len+1, .... Forgetting this means decode tokens always get position 0 → model is confused about ordering → outputs degrade after the first decoded token.
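With RoPE, the same `past_len` offset applies, but as a rotation of q and k inside attention: rotate each new key at its absolute position before appending it to the cache. With ALiBi there is no position embedding to advance, but the distance bias must still be computed from each query's absolute position (`T_total - T + i`), just like the causal mask.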
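For RoPE, the only cache-specific detail is the position fed to the rotation. A minimal sketch, assuming the interleaved-pair convention (implementations differ in how they pair dimensions; `rope` is a hypothetical helper): rotating the new keys at `past_len, past_len + 1, ...` reproduces the full-sequence result, restarting at 0 does not, and the query-key score depends only on the position difference, which is why already-rotated keys can be cached as-is.

```python
import torch

def rope(x, positions, base=10000.0):
    """Rotate pairs (x[2i], x[2i+1]) by angle pos * base**(-2i/d); d must be even."""
    d = x.size(-1)
    theta = base ** (-torch.arange(0, d, 2, dtype=x.dtype) / d)  # (d/2,)
    ang = positions.to(x.dtype)[:, None] * theta[None, :]         # (T, d/2)
    cos, sin = ang.cos(), ang.sin()
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

torch.manual_seed(0)
T, d, past_len = 10, 8, 7
k = torch.randn(T, d, dtype=torch.float64)
full = rope(k, torch.arange(T))                          # whole sequence at once
decode = rope(k[past_len:], torch.arange(past_len, T))   # new tokens at offset positions
wrong = rope(k[past_len:], torch.arange(T - past_len))   # positions restarted at 0

q = torch.randn(1, d, dtype=torch.float64)
k0 = k[3:4]
s1 = rope(q, torch.tensor([9])) @ rope(k0, torch.tensor([3])).T   # distance 6
s2 = rope(q, torch.tensor([15])) @ rope(k0, torch.tensor([9])).T  # distance 6, shifted by 6

print(torch.allclose(full[past_len:], decode),
      torch.allclose(full[past_len:], wrong),
      torch.allclose(s1, s2))
```

The three checks print `True False True`.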
8. The benchmark
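```python
import time

import torch

def timed(fn, *args, **kwargs):
    """Time one call, synchronizing so queued CUDA kernels are counted."""
    sync = torch.cuda.synchronize if torch.cuda.is_available() else (lambda: None)
    sync()
    t0 = time.perf_counter()
    fn(*args, **kwargs)
    sync()
    return time.perf_counter() - t0

# model, V (vocab size) and device are set up earlier in the lab
for seq_len in [64, 128, 256, 512]:
    prompt = torch.randint(0, V, (1, seq_len), device=device)
    timed(generate_no_cache, model, prompt, max_new=2)  # warmup, discarded
    timed(generate_kv_cache, model, prompt, max_new=2)  # warmup, discarded
    t_naive = timed(generate_no_cache, model, prompt, max_new=64)
    t_cache = timed(generate_kv_cache, model, prompt, max_new=64)
    print(f"prompt={seq_len:4d} naive={t_naive*1000:.1f}ms cache={t_cache*1000:.1f}ms speedup={t_naive/t_cache:.1f}×")
```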
Expected (small model, RTX 4090):
```text
prompt= 64 naive= 480ms cache= 62ms speedup= 7.7×
prompt= 128 naive= 920ms cache= 74ms speedup=12.4×
prompt= 256 naive=2100ms cache= 98ms speedup=21.4×
prompt= 512 naive=6800ms cache= 145ms speedup=46.9×
```
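Speedup grows with prompt length: the no-cache path re-runs the whole, ever-longer sequence at each of the 64 steps, while the cached path processes the prompt once and then a single token per step (over a full generation, naive attention cost is cubic in sequence length). For real LLM serving (prompts of 1k–10k tokens), the no-cache path is unusable.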
9. From this lab to vLLM
What vLLM adds on top of what you just built:
| Feature | What | Why |
|---|---|---|
| PagedAttention | KV cache stored in fixed-size blocks (16 tokens), virtualized | Eliminates fragmentation; enables prefix caching |
| Continuous batching | New requests join the running batch at decode-step boundaries | 2–5× throughput vs static batching |
| Prefix caching | Reuse KV across requests sharing a prompt prefix | Massive speedup for system-prompt-heavy workloads |
| Speculative decoding | Small draft model proposes tokens; big model verifies | 2–3× latency reduction |
| FlashAttention | Fused, IO-aware attention kernel | 2–3× attention speedup |
| Quantization | INT8/INT4/FP8 weights and KV cache | Fit bigger models / longer contexts |
You now have the conceptual foundation to read vLLM's source code without it feeling magical.
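A toy version of the block-table idea connects PagedAttention back to `LayerCache`: a shared pool of fixed-size blocks, a free list, and a per-sequence table mapping logical block `i` to a physical block. `PagedKVCache` and its methods are hypothetical (not vLLM's API), and for simplicity `gather` copies the blocks into a contiguous tensor, whereas a real paged attention kernel reads them in place.

```python
import torch

class PagedKVCache:
    """Toy single-layer paged KV cache (hypothetical, for illustration only)."""

    def __init__(self, num_blocks, block_size, n_head, d_head):
        self.block_size = block_size
        shape = (num_blocks, block_size, n_head, d_head)
        self.k_pool, self.v_pool = torch.zeros(shape), torch.zeros(shape)
        self.free = list(range(num_blocks))  # physical blocks not in use
        self.tables = {}                     # seq_id -> list of physical block ids
        self.lengths = {}                    # seq_id -> tokens stored so far

    def append(self, seq_id, k, v):
        """Store one token's k, v (each (n_head, d_head)) for sequence seq_id."""
        table = self.tables.setdefault(seq_id, [])
        n = self.lengths.get(seq_id, 0)
        if n % self.block_size == 0:         # last block full (or no block yet)
            table.append(self.free.pop())    # IndexError here means the pool is exhausted
        block, offset = table[n // self.block_size], n % self.block_size
        self.k_pool[block, offset] = k
        self.v_pool[block, offset] = v
        self.lengths[seq_id] = n + 1

    def gather(self, seq_id):
        """Return contiguous (T, n_head, d_head) K and V by following the block table."""
        n, idx = self.lengths[seq_id], torch.tensor(self.tables[seq_id])
        return self.k_pool[idx].flatten(0, 1)[:n], self.v_pool[idx].flatten(0, 1)[:n]

    def free_seq(self, seq_id):
        """Return a finished sequence's blocks to the pool."""
        self.free.extend(self.tables.pop(seq_id))
        del self.lengths[seq_id]

torch.manual_seed(0)
cache = PagedKVCache(num_blocks=8, block_size=4, n_head=2, d_head=3)
ks, vs = torch.randn(10, 2, 3), torch.randn(10, 2, 3)
for t in range(10):
    cache.append("a", ks[t], vs[t])
k, v = cache.gather("a")
print(torch.equal(k, ks), torch.equal(v, vs), cache.tables["a"])
```

Ten tokens occupy three 4-token blocks, the last one partly empty; that bounded tail waste, rather than a reservation for `max_seq_len`, is where the fragmentation savings come from.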
10. Common pitfalls
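- Forgetting to advance position embeddings during decode: quality silently degrades.
- Mask shape `(T, T)` instead of `(T, T_total)` once the cache is non-empty, or a `(T, T_total)` mask built without the `T_total - T` diagonal offset: crash or wrong attention.
- Re-creating `LayerCache` per decode step: the caches must persist across the decode loop.
- Not using `@torch.no_grad()`: OOM on long generations, because autograd keeps the graph behind every cached tensor alive.
- Confusing prefill and decode paths: you must handle both correctly, T > 1 for prefill and T == 1 for decode.
- Comparing wall-time without warmup: the first run pays one-time CUDA costs (context creation, library initialization, kernel loading), so always discard it. CUDA kernels also launch asynchronously; call `torch.cuda.synchronize()` before reading the timer.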
11. Stretch exercises
- Pre-allocate the cache to `max_seq_len` instead of concatenating. Compare speed.
- Implement paged caching: store K/V in fixed blocks (e.g., 16 tokens each), use a block table for indirection. Foundation of vLLM.
- Add prefix caching: detect when two sequences share a prefix; share the K/V blocks. ~5–1000× speedup for repeated system prompts.
- Implement speculative decoding: draft with a small model, verify with the big one. The hardest exercise; 2–3× latency reduction at the cost of complexity.
- Quantize the KV cache to INT8: store K, V as INT8 with per-channel scale; dequantize before attention. Halves cache memory relative to FP16.
- Profile with `nsys`: prove that decode is memory-bound (low compute utilization, high DRAM read bandwidth).
- Plug in FlashAttention: replace the manual `(q @ k.transpose(-2, -1)) / sqrt(d_head); softmax; @ v` with `F.scaled_dot_product_attention`. Re-benchmark.
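A minimal sketch of the INT8 exercise, assuming symmetric quantization with one scale per (batch, head, channel) computed over the whole time axis (`quantize_int8` and `dequantize_int8` are hypothetical helpers). A growing cache cannot re-derive such a scale every step without re-quantizing the past, so practical schemes usually fix scales per block or per token; this toy only shows the storage format and the rounding error bound.

```python
import torch

def quantize_int8(x):
    """x: (B, n_head, T, d_head) float -> int8 values and a (B, n_head, 1, d_head) scale."""
    scale = x.abs().amax(dim=2, keepdim=True).clamp(min=1e-8) / 127.0  # one scale per channel
    q = torch.round(x / scale).clamp(-127, 127).to(torch.int8)
    return q, scale

def dequantize_int8(q, scale, dtype=torch.float32):
    """Undo the scaling right before attention, in the model's compute dtype."""
    return (q.to(scale.dtype) * scale).to(dtype)

torch.manual_seed(0)
k = torch.randn(1, 4, 256, 64).half()  # stand-in for one layer's cached K in FP16
qk, scale = quantize_int8(k.float())
k_hat = dequantize_int8(qk, scale)
err = (k_hat - k.float()).abs()
print(qk.element_size(), k.element_size())    # bytes per element: INT8 vs FP16
print(bool((err <= scale / 2 + 1e-5).all()))  # error stays within half a quantization step
```

The scale tensor adds only `d_head` floats per head, which is negligible next to the `T × d_head` values it covers.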
12. What this lab proves about you
You understand inference at the level required for LLM Inference Engineer roles. You can:
- Explain why decode is memory-bound and prefill is compute-bound.
- Articulate the math behind the O(T²) → O(T) speedup.
- Implement (and debug) a KV cache from scratch.
- Read vLLM's source and connect every concept to your implementation.
This is the highest-leverage Phase-9 milestone — KV cache + continuous batching + PagedAttention is essentially the entire interview surface for inference roles.