Lab 01 — KV-Cache From Scratch (Solution Walkthrough)

Phase: 9 — Inference & Serving | Difficulty: ⭐⭐⭐⭐☆ | Time: 3–5 hours

Concept primer: ../HITCHHIKERS-GUIDE.md §KV cache, §Prefill vs decode, §PagedAttention.

Run

pip install -r requirements.txt
python solution.py

0. The mission

Retrofit the Phase-4 transformer with a KV cache and measure the speedup. This is the single most important inference optimization — every production engine (vLLM, TGI, TensorRT-LLM, llama.cpp) is structured around managing this cache.

The two questions you must answer at the end:

  1. Why is decoding without a cache O(T²) per generated token?
  2. Why does the cache reduce it to O(T) and enable continuous batching?

1. The math

For a sequence of length $T$ and model width $d$, the two attention matmuls ($QK^\top$ and the attention-weighted sum over $V$, each about $2T^2 d$ FLOPs) cost, per layer:

$$ \text{Attention FLOPs} \approx 4 T^2 d $$

(quadratic in $T$). When generating token $T+1$ without a cache, you re-process tokens $1..T$ from scratch, so each generated token costs $O(T^2 d)$. Total cost to generate $N$ tokens from a prompt of length $P$:

$$ \sum_{t=P}^{P+N} O(t^2 d) = O\!\left((P+N)^3 d\right) $$

With a KV cache, when generating token $T+1$:

  • Compute Q, K, V only for the new token (1 token) and append the new K, V to the cache.
  • Look up cached K, V for tokens $1..T$.
  • Compute attention as $q \cdot K^\top$ which is $O(T \cdot d)$.

Generating $N$ tokens after prefilling $P$:

$$ O(P^2 d) \text{ for prefill} + \sum_{t=P}^{P+N} O(t \cdot d) = O\!\left((P+N)^2 d\right) $$

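For $P=128, N=128$: $(P+N)^3 / (P+N)^2 = 256$, so roughly 256× fewer attention FLOPs as an asymptotic estimate. Constant factors make the exact attention-FLOP ratio smaller, and non-attention FLOPs (MLP, projections) are not part of this estimate.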
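The 256× figure is a ratio of asymptotic bounds. A minimal plain-Python sketch can count the attention-matmul FLOPs of the two generation loops in section 5 exactly. It assumes the full $T_q \times T_k$ score matrix is computed before masking (as in the naive attention of section 4), counts a single layer, and ignores MLP and projection FLOPs; the function names are illustrative.

```python
def attn_flops(T_q: int, T_k: int, d: int) -> int:
    """QK^T plus att @ V for one layer: about 2*T_q*T_k*d FLOPs each."""
    return 4 * T_q * T_k * d


def no_cache_flops(P: int, N: int, d: int) -> int:
    # generate_no_cache: step s re-runs full attention over the P + s tokens seen so far.
    return sum(attn_flops(P + s, P + s, d) for s in range(N))


def kv_cache_flops(P: int, N: int, d: int) -> int:
    # generate_kv_cache: one prefill over P tokens produces the first new token ...
    prefill = attn_flops(P, P, d)
    # ... then N - 1 decode steps, each 1 query against P + s cached keys.
    decode = sum(attn_flops(1, P + s, d) for s in range(1, N))
    return prefill + decode


P, N, d = 128, 128, 64
ratio = no_cache_flops(P, N, d) / kv_cache_flops(P, N, d)
print(f"attention-FLOP ratio: {ratio:.1f}x")
```

For $P = N = 128$ this prints a ratio of about 119×: the same order of magnitude as the 256× estimate, with the gap coming from the constant factors that the big-O bounds drop. The ratio does not depend on $d$, since both counts scale linearly with it.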


2. The two phases of inference

The single most important conceptual split in serving:

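| Phase | Input | Compute character | Bottleneck |
|---|---|---|---|
| Prefill | All P prompt tokens at once | Compute-bound (large matmuls) | Compute throughput (TFLOPS) |
| Decode | One new token per step, N steps | Memory-bound (tiny matmuls; all weights and the KV cache are read every step) | Memory bandwidth |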

Metrics map directly:

  • TTFT (time to first token) = prefill latency.
  • ITL (inter-token latency) = decode latency.

Batching helps decode enormously (every request in the batch shares a single read of the weights) but barely helps prefill, which is already compute-saturated. Requests spend most of their lifetime in decode, so continuous batching admits new requests into the running batch at decode-step boundaries instead of waiting for the whole batch to finish.
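A back-of-the-envelope arithmetic-intensity estimate shows why a single decode token is memory-bound and why batching fixes it. The sketch looks at one $d \times d$ weight matmul applied to $M$ tokens, assuming FP16 (2 bytes per element) and $d = 4096$. It counts only the bytes of the weight, the input and the output, and ignores the KV-cache reads that make decode even more bandwidth-hungry; the helper name is illustrative.

```python
def linear_intensity(M: int, d: int, bytes_per_el: int = 2) -> float:
    """Arithmetic intensity (FLOPs per byte) of y = x @ W, x: (M, d), W: (d, d)."""
    flops = 2 * M * d * d                               # one multiply-add per weight per token
    bytes_moved = bytes_per_el * (d * d + 2 * M * d)    # read W once, read x, write y
    return flops / bytes_moved


d = 4096
for label, M in [("decode, batch 1", 1), ("decode, batch 64", 64), ("prefill, 512 tokens", 512)]:
    print(f"{label:20s} {linear_intensity(M, d):7.1f} FLOPs/byte")
```

This gives roughly 1 FLOP per byte for a single decode token, about 62 for a decode batch of 64, and about 410 for a 512-token prefill. Current datacenter GPUs need on the order of a hundred or more FLOPs per byte before compute becomes the limit, so single-token decode leaves most arithmetic units idle while prefill saturates them.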


3. LayerCache — the data structure

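from __future__ import annotations

from dataclasses import dataclass

import torch

@dataclass
class LayerCache:
    """Per-layer K/V cache, grown along the time axis (dim=2)."""
    k: torch.Tensor | None = None    # (B, n_head, T_cur, d_head)
    v: torch.Tensor | None = None    # same shape as k

    def append(self, new_k, new_v):
        """Append new K, V of shape (B, n_head, T_new, d_head); return the full cache."""
        if self.k is None:
            self.k = new_k
            self.v = new_v
        else:
            self.k = torch.cat([self.k, new_k], dim=2)
            self.v = torch.cat([self.v, new_v], dim=2)
        return self.k, self.v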

Design decisions:

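  • Per-layer cache: each transformer layer has its own K, V tensors. Total cache size = n_layer × 2 × B × n_head × T × d_head × dtype_bytes. For Llama-7B in FP16 at T=2048: ~1 GB per request. This footprint, multiplied by every concurrent request, is why GPU memory (capacity as well as bandwidth) is the scarce resource in serving.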
  • Concat on dim=2 (the time dim). Naive but correct. Production engines (vLLM) don't concat — they use paged allocation in fixed-size blocks (16 tokens) to avoid the O(T) reallocation and to enable shared prefix caching.
  • Naive concat is O(T) per step — every decode step copies the entire growing cache. For long contexts this becomes a bottleneck. Pre-allocating a max-size buffer fixes this; PagedAttention generalizes the fix.
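The sizing formula can be checked directly. A small sketch, assuming a Llama-7B-shaped configuration (32 layers, 32 attention heads of dimension 128, full multi-head attention rather than grouped-query attention) with FP16 K and V; the helper name is illustrative:

```python
def kv_cache_bytes(n_layer: int, n_head: int, d_head: int, seq_len: int,
                   batch: int = 1, dtype_bytes: int = 2) -> int:
    # K and V per layer, each of shape (batch, n_head, seq_len, d_head)
    return n_layer * 2 * batch * n_head * seq_len * d_head * dtype_bytes


# Assumed Llama-7B-shaped config: 32 layers, 32 heads, head dim 128, FP16
llama7b = dict(n_layer=32, n_head=32, d_head=128)
per_token = kv_cache_bytes(**llama7b, seq_len=1)
per_request = kv_cache_bytes(**llama7b, seq_len=2048)
print(f"{per_token / 2**10:.0f} KiB per token, {per_request / 2**30:.2f} GiB per 2048-token request")
```

That is 512 KiB per token and exactly 1 GiB for a 2048-token request, matching the ~1 GB figure above.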

4. CachedSelfAttention — the modified forward

# CachedSelfAttention.forward (module needs: import math, torch, torch.nn.functional as F)
def forward(self, x, cache: LayerCache | None = None):
    B, T, C = x.shape
    qkv = self.qkv(x)
    q, k, v = qkv.split(C, dim=-1)
    q = q.view(B, T, self.n_head, self.d_head).transpose(1, 2)
    k = k.view(B, T, self.n_head, self.d_head).transpose(1, 2)
    v = v.view(B, T, self.n_head, self.d_head).transpose(1, 2)

    if cache is not None:
        k, v = cache.append(k, v)              # 👈 cached K, V first, new K, V appended on the time axis

    T_total = k.size(2)
    att = (q @ k.transpose(-2, -1)) / math.sqrt(self.d_head)   # (B, n_head, T, T_total)

    # Causal mask: new query i sits at absolute position past_len + i (past_len = T_total - T)
    # and attends to k[: past_len + i + 1]; diagonal=past_len shifts the triangle right.
    if T > 1:                                  # prefill, or a multi-token chunk on top of a cache
        past_len = T_total - T
        mask = torch.tril(torch.ones(T, T_total, dtype=torch.bool, device=x.device), diagonal=past_len)
        att = att.masked_fill(~mask, float("-inf"))
    # decode (T == 1) needs no mask: the single query may attend to every cached key

    att = F.softmax(att, dim=-1)
    y = att @ v
    y = y.transpose(1, 2).contiguous().view(B, T, C)
    return self.proj(y)

Three changes from Phase 4's attention:

  1. K, V come from concatenation: new K, V for the just-arrived tokens; previous K, V from the cache.
  2. Q is only for the new tokens (length T), but K, V cover the full length T_total.
  3. Mask shape is (T, T_total) — rows are queries, cols are keys. Position i (in the new tokens) corresponds to absolute position T_total - T + i, and can attend to keys 0..T_total - T + i.

During decode (T == 1), the mask is trivially "attend to everything" — we skip computing it.
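The mask rule from point 3 can be written out in plain Python (the helper name is illustrative). It reduces to the usual lower triangle for prefill from an empty cache and to an all-True row for decode, and it shows why the offset matters when T > 1 and the cache already holds keys, for example in chunked prefill: a triangle anchored at the top-left corner would let new query i see only keys 0..i.

```python
def causal_mask(T: int, T_total: int) -> list[list[bool]]:
    """mask[i][j] is True if new query i may attend to key j.

    Query i sits at absolute position past_len + i, with past_len = T_total - T,
    so it sees keys 0..past_len + i. Same pattern as
    torch.tril(torch.ones(T, T_total, dtype=torch.bool), diagonal=T_total - T).
    """
    past_len = T_total - T
    return [[j <= past_len + i for j in range(T_total)] for i in range(T)]


for row in causal_mask(T=2, T_total=5):      # 3 cached tokens + 2 new ones
    print("".join("1" if ok else "." for ok in row))
# 1111.
# 11111
```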


5. The two generation paths

5.1 Reference: no-cache generation

@torch.no_grad()
def generate_no_cache(model, prompt, max_new):
    out = prompt.clone()
    for _ in range(max_new):
        logits = model(out)              # 👈 reprocesses the entire sequence
        next_id = logits[:, -1, :].argmax(-1, keepdim=True)
        out = torch.cat([out, next_id], dim=1)
    return out

Cost: each iteration re-runs the full transformer over the entire current sequence. Sublime in its inefficiency.

5.2 With cache

@torch.no_grad()
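def generate_kv_cache(model, prompt, max_new):
    if max_new <= 0:                           # nothing to generate (matches generate_no_cache)
        return prompt.clone()
    caches = [LayerCache() for _ in range(model.n_layer)]   # one cache per transformer block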

    # PREFILL: process the prompt once, populate caches
    logits = model(prompt, caches=caches)
    next_id = logits[:, -1, :].argmax(-1, keepdim=True)
    out = torch.cat([prompt, next_id], dim=1)

    # DECODE: feed only the new token, reuse caches
    for _ in range(max_new - 1):
        logits = model(next_id, caches=caches)     # input is (B, 1)
        next_id = logits[:, -1, :].argmax(-1, keepdim=True)
        out = torch.cat([out, next_id], dim=1)
    return out

  • caches is a list of LayerCache, one per transformer block. Mutated in-place by each forward pass.
  • Prefill consumes the prompt; decode steps consume one token each.
  • The model's forward accepts an optional caches list and threads them to the right attention layers.

6. Correctness verification

The most important test:

out1 = generate_no_cache(model, prompt, max_new=64)
out2 = generate_kv_cache(model, prompt, max_new=64)
assert torch.equal(out1, out2), "KV cache must produce identical tokens"

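Because decoding is greedy (argmax), both paths should produce identical token sequences: they compute the same function and differ only in how much redundant work they do. One caveat: the two paths run matmuls of different shapes, so in FP16/BF16 on a GPU tiny rounding differences can occasionally flip a near-tied argmax. If the assertion fails only late in a long run, compare the per-step logits with torch.allclose before hunting for a bug. If they differ:

  • Cache appended twice per step (the new tokens' K, V end up in the cache twice).
  • Wrong mask shape or offset: (T, T) instead of (T, T_total), or a mask not shifted by the cache length.
  • Position embedding bug — you forgot to advance positions during decode.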

If you're using sampled (non-deterministic) generation, fix the seed and the same property holds.
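Stripped of PyTorch, the property this test relies on can be sketched in plain Python: one attention head (no MLP, no position information, random weights; all illustrative), computed once by full recomputation for every prefix and once through a growing K/V cache. The output for each new position should agree to floating-point rounding.

```python
import math
import random


def matvec(W, x):
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def attend(q, K, V):
    """softmax(q . K^T / sqrt(d)) @ V for a single query vector."""
    scale = math.sqrt(len(q))
    scores = [sum(a * b for a, b in zip(q, k)) / scale for k in K]
    m = max(scores)                                # subtract the max for numerical stability
    w = [math.exp(s - m) for s in scores]
    z = sum(w)
    return [sum(wi * v[c] for wi, v in zip(w, V)) / z for c in range(len(V[0]))]


def full_recompute(xs, Wq, Wk, Wv):
    """No cache: for each prefix, re-project every token and attend from the last one."""
    outs = []
    for t in range(1, len(xs) + 1):
        K = [matvec(Wk, x) for x in xs[:t]]
        V = [matvec(Wv, x) for x in xs[:t]]
        outs.append(attend(matvec(Wq, xs[t - 1]), K, V))
    return outs


def incremental(xs, Wq, Wk, Wv):
    """KV cache: project only the new token, append its k and v, attend over the cache."""
    K_cache, V_cache, outs = [], [], []
    for x in xs:
        K_cache.append(matvec(Wk, x))
        V_cache.append(matvec(Wv, x))
        outs.append(attend(matvec(Wq, x), K_cache, V_cache))
    return outs


def rand_matrix(rng, rows, cols):
    return [[rng.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]


rng = random.Random(0)
d, T = 4, 6
Wq, Wk, Wv = (rand_matrix(rng, d, d) for _ in range(3))
xs = rand_matrix(rng, T, d)                        # T token vectors of width d

full = full_recompute(xs, Wq, Wk, Wv)
cached = incremental(xs, Wq, Wk, Wv)
max_diff = max(abs(p - q) for row_f, row_c in zip(full, cached) for p, q in zip(row_f, row_c))
print(f"max |full - cached| = {max_diff:.3g}")
```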


7. Position embeddings during decode

With learned absolute position embeddings:

def forward(self, idx, caches=None):
    B, T = idx.shape
    past_len = caches[0].k.size(2) if caches and caches[0].k is not None else 0
    pos = torch.arange(past_len, past_len + T, device=idx.device)
    x = self.tok_emb(idx) + self.pos_emb(pos)
    ...

The new tokens get positions past_len, past_len+1, .... Forgetting this means decode tokens always get position 0 → model is confused about ordering → outputs degrade after the first decoded token.

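With RoPE, the same offset applies, but as a rotation inside attention: rotate q and k for absolute positions past_len..past_len+T-1, and cache the already-rotated keys. With ALiBi there is no position embedding to offset, but the distance-based attention bias must still use the same query offset (past_len) as the causal mask.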
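The RoPE part rests on one property: after rotation, the score between a query at position m and a key at position n depends only on m - n. A plain-Python sketch checks this numerically, using the interleaved-pair convention with base 10000 (some codebases rotate the two halves of the vector instead, which is the same idea under a permutation of dimensions); the vectors and helper names are illustrative. This property is why keys can be rotated once at their absolute position and cached as-is.

```python
import math


def rope(x: list[float], pos: int, base: float = 10000.0) -> list[float]:
    """Rotate consecutive pairs (x[2i], x[2i+1]) by angle pos * base**(-2i/d)."""
    d = len(x)
    out = list(x)
    for i in range(d // 2):
        theta = pos * base ** (-2 * i / d)
        c, s = math.cos(theta), math.sin(theta)
        a, b = x[2 * i], x[2 * i + 1]
        out[2 * i] = a * c - b * s
        out[2 * i + 1] = a * s + b * c
    return out


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


q = [0.3, -1.2, 0.8, 0.5]
k = [1.0, 0.4, -0.7, 0.2]
s1 = dot(rope(q, 7), rope(k, 3))   # query at position 7 vs a key cached at position 3
s2 = dot(rope(q, 4), rope(k, 0))   # a different pair, also 4 positions apart
print(f"{s1:.6f} {s2:.6f}")
```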


8. The benchmark

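import time

import torch

def timed(fn, *args, **kwargs):
    """Wall time of one call to fn, after one discarded warmup call."""
    fn(*args, **kwargs)                        # warmup: pays one-time CUDA/allocator setup
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    t0 = time.perf_counter()
    fn(*args, **kwargs)
    if torch.cuda.is_available():
        torch.cuda.synchronize()               # kernels run asynchronously; wait before reading the clock
    return time.perf_counter() - t0

# assumes model, V (vocab size) and device are already defined
for seq_len in [64, 128, 256, 512]:
    prompt = torch.randint(0, V, (1, seq_len), device=device)
    t_naive = timed(generate_no_cache, model, prompt, max_new=64)
    t_cache = timed(generate_kv_cache, model, prompt, max_new=64)
    print(f"prompt={seq_len:4d}  naive={t_naive*1000:.1f}ms  cache={t_cache*1000:.1f}ms  speedup={t_naive/t_cache:.1f}×")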

Example run (small model, RTX 4090); absolute numbers depend on model size and hardware:

prompt=  64  naive= 480ms  cache=  62ms  speedup= 7.7×
prompt= 128  naive= 920ms  cache=  74ms  speedup=12.4×
prompt= 256  naive=2100ms  cache=  98ms  speedup=21.4×
prompt= 512  naive=6800ms  cache= 145ms  speedup=46.9×

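Speedup grows with prompt length because the no-cache path re-runs a full forward pass over the whole, growing sequence at every step, while the cached path processes the prompt once and then one token per step (cubic versus quadratic total cost in sequence length). For real LLM serving (prompts of 1k–10k tokens), the no-cache path is unusable.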


9. From this lab to vLLM

What vLLM adds on top of what you just built:

| Feature | What | Why |
|---|---|---|
| PagedAttention | KV cache stored in fixed-size blocks (16 tokens), virtualized | Eliminates fragmentation; enables prefix caching |
| Continuous batching | New requests join the running batch at decode-step boundaries | 2–5× throughput vs static batching |
| Prefix caching | Reuse KV across requests sharing a prompt prefix | Massive speedup for system-prompt-heavy workloads |
| Speculative decoding | Small draft model proposes tokens; big model verifies | 2–3× latency reduction |
| FlashAttention | Fused, IO-aware attention kernel | 2–3× attention speedup |
| Quantization | INT8/INT4/FP8 weights and KV cache | Fit bigger models / longer contexts |
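A toy simulation conveys the idea behind continuous batching. It assumes every request is queued at time zero, needs a known number of decode steps, and that one decode step takes the same time regardless of batch size (a rough stand-in for the memory-bound regime). Prefill, memory limits and real scheduling policies are ignored, so this is an illustration, not vLLM's scheduler; the function names are hypothetical.

```python
def static_batching_steps(lengths: list[int], batch: int) -> int:
    """Fixed batches in arrival order; each batch runs until its longest request ends."""
    return sum(max(lengths[i:i + batch]) for i in range(0, len(lengths), batch))


def continuous_batching_steps(lengths: list[int], batch: int) -> int:
    """Refill free slots from the queue at every decode-step boundary."""
    queue = list(lengths)
    active: list[int] = []                       # remaining decode steps per running request
    steps = 0
    while queue or active:
        while queue and len(active) < batch:     # admit waiting requests into free slots
            active.append(queue.pop(0))
        active = [r - 1 for r in active if r > 1]   # one step for everyone; drop finished ones
        steps += 1
    return steps


lengths = [2, 8, 3, 8, 1, 2]                     # decode steps needed per request (all > 0)
print(static_batching_steps(lengths, batch=2), continuous_batching_steps(lengths, batch=2))
```

With two slots, static batching needs 18 steps and continuous batching 13: static batches leave short-request slots idle while waiting for their longest member, whereas refilling at step boundaries keeps both slots busy.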

You now have the conceptual foundation to read vLLM's source code without it feeling magical.


10. Common pitfalls

  1. Forgetting to advance position embeddings during decode — quality silently degrades.
  2. Mask shape (T, T) instead of (T, T_total), or a (T, T_total) mask not shifted by the cache length when T > 1 and the cache is non-empty: a crash or silently wrong attention.
  3. Re-creating LayerCache per decode step — must persist across the decode loop.
  4. Not using @torch.no_grad() — OOM on long generations.
  5. Confusing prefill and decode paths — must handle both correctly: T > 1 for prefill, T == 1 for decode.
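  6. Comparing wall-time without warmup or synchronization: the first run pays one-time CUDA setup costs (context creation, lazy kernel loading, library handles, allocator warmup), and kernels launch asynchronously. Always discard the first iteration and call torch.cuda.synchronize() before reading the timer.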

11. Stretch exercises

  • Pre-allocate the cache to max_seq_len instead of concatenating. Compare speed.
  • Implement paged caching: store K/V in fixed blocks (e.g., 16 tokens each), use a block table for indirection. Foundation of vLLM.
  • Add prefix caching: detect when two sequences share a prefix; share the K/V blocks. ~5–1000× speedup for repeated system prompts.
  • Implement speculative decoding: draft with a small model, verify with the big one. The hardest exercise; 2–3× latency reduction at the cost of complexity.
  • Quantize the KV cache to INT8: store K, V as INT8 with per-channel scales and dequantize before attention. Halves cache memory relative to FP16/BF16, plus a small overhead for the scales.
  • Profile with nsys: prove that decode is memory-bound (low compute utilization, high DRAM read bandwidth).
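  • Plug in FlashAttention: replace the manual (q @ k.T) / sqrt(d) ; softmax ; @ v with F.scaled_dot_product_attention, which can dispatch to a FlashAttention kernel on supported GPUs and dtypes. Mind the mask: is_causal=True aligns the triangle to the top-left, which is right for prefill from an empty cache but wrong whenever cached keys are present; pass the offset mask as attn_mask for T > 1 with a cache, and no mask for T == 1 decode. Re-benchmark.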
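The block-table indirection behind the paged-caching exercise can be sketched in plain Python, independent of any GPU kernel. The class and method names are hypothetical, and each slot holds an opaque per-token entry standing in for a (k, v) pair; in vLLM, K and V live in GPU tensors and the PagedAttention kernel reads the blocks through the table instead of gathering them into a contiguous copy.

```python
class PagedKVCache:
    """Toy paged cache: per-token entries live in fixed-size physical blocks, and
    each sequence keeps a block table mapping logical block -> physical block."""

    def __init__(self, num_blocks: int, block_size: int = 16):
        self.block_size = block_size
        self.blocks = [[None] * block_size for _ in range(num_blocks)]   # physical storage
        self.free = list(range(num_blocks))                              # free list of block ids
        self.tables: dict[int, list[int]] = {}                           # seq_id -> block table
        self.lengths: dict[int, int] = {}                                # seq_id -> tokens stored

    def append(self, seq_id: int, kv: object) -> None:
        table = self.tables.setdefault(seq_id, [])
        n = self.lengths.get(seq_id, 0)
        if n % self.block_size == 0:              # last block full (or none yet): grab a new one
            if not self.free:
                raise MemoryError("out of KV blocks")
            table.append(self.free.pop())
        block_id = table[n // self.block_size]    # logical block -> physical block
        self.blocks[block_id][n % self.block_size] = kv
        self.lengths[seq_id] = n + 1

    def gather(self, seq_id: int) -> list:
        """Reassemble the logical sequence by following the block table."""
        n = self.lengths.get(seq_id, 0)
        table = self.tables.get(seq_id, [])
        return [self.blocks[table[i // self.block_size]][i % self.block_size] for i in range(n)]

    def free_seq(self, seq_id: int) -> None:
        """Return a finished sequence's blocks to the free list."""
        self.free.extend(self.tables.pop(seq_id, []))
        self.lengths.pop(seq_id, None)


cache = PagedKVCache(num_blocks=8, block_size=4)
for t in range(10):
    cache.append(seq_id=0, kv=("kv", t))          # stand-in for the (k, v) of token t
for t in range(3):
    cache.append(seq_id=1, kv=("kv", 100 + t))
print(len(cache.tables[0]), len(cache.tables[1]), len(cache.free))
```

Ten tokens with a block size of 4 occupy three blocks and the last one is only half full, so at most block_size - 1 slots per sequence are wasted: that bounded fragmentation is what makes paging attractive.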

12. What this lab proves about you

You understand inference at the level required for LLM Inference Engineer roles. You can:

  • Explain why decode is memory-bound and prefill is compute-bound.
  • Articulate the math behind the O(T²) → O(T) speedup.
  • Implement (and debug) a KV cache from scratch.
  • Read vLLM's source and connect every concept to your implementation.

This is the highest-leverage Phase-9 milestone — KV cache + continuous batching + PagedAttention is essentially the entire interview surface for inference roles.