"""
LLM Coding Questions — implement-from-scratch challenges.

Each function/class has a signature, a docstring, and a reference body; `_test()`
at the bottom checks them. To practice, replace a body with your own
implementation, then run `python 02-llm-coding-questions.py` to check it.

Topics:
  Q1. Scaled dot-product attention (single-head)
  Q2. Multi-head attention with causal mask
  Q3. KV-cache wrapper (not implemented here; see phase-09-inference-serving/lab-01-kv-cache)
  Q4. BPE encoder (given trained merges)
  Q5. Top-p (nucleus) sampling
  Q6. Beam search (length-normalized)
  Q7. RoPE (rotary positional embeddings) apply to Q and K
  Q8. Token-aware text chunker (sliding window)
  Q9. Sliding-window perplexity over a long doc
  Q10. Speculative-decoding accept loop (rejection sampling)
"""
from __future__ import annotations
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


# ---------------------------------------------------------------------------
# Q1. Scaled dot-product attention (single-head)
# ---------------------------------------------------------------------------
def attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
              mask: torch.Tensor | None = None) -> torch.Tensor:
    """
    Single-head scaled dot-product attention: softmax(q k^T / sqrt(D)) v.

    q, k, v: (B, T, D)
    mask: (T, T) bool, True = keep, or None
    Returns: (B, T, D)
    """
    logits = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.shape[-1])
    if mask is None:
        weights = logits.softmax(dim=-1)
    else:
        # Masked-out positions get -inf, so they receive zero weight after softmax.
        weights = logits.masked_fill(~mask, float("-inf")).softmax(dim=-1)
    return torch.matmul(weights, v)


# ---------------------------------------------------------------------------
# Q2. Multi-head attention with causal mask
# ---------------------------------------------------------------------------
class MultiHeadAttention(nn.Module):
    """
    Causal multi-head self-attention with a fused, bias-free QKV projection.

    d_model must be divisible by n_heads; each head works on d_model // n_heads
    channels, and position t only attends to positions <= t.
    """

    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        # Submodule names (qkv, proj) are part of the state_dict layout.
        self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)
        self.proj = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """x: (B, T, d_model) -> (B, T, d_model)."""
        batch, seq, width = x.shape
        heads, head_dim = self.n_heads, self.d_head
        # (B, T, 3C) -> (3, B, H, T, Dh): one reshape + permute instead of
        # split followed by three per-tensor views/transposes.
        fused = self.qkv(x).view(batch, seq, 3, heads, head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = fused[0], fused[1], fused[2]
        logits = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(head_dim)
        allowed = torch.ones(seq, seq, dtype=torch.bool, device=x.device).tril()
        weights = logits.masked_fill(allowed.logical_not(), float("-inf")).softmax(dim=-1)
        # (B, H, T, Dh) -> (B, T, H, Dh) -> (B, T, C)
        mixed = torch.matmul(weights, v).transpose(1, 2).reshape(batch, seq, width)
        return self.proj(mixed)


# ---------------------------------------------------------------------------
# Q3. KV-cache wrapper — see phase-09-inference-serving/lab-01-kv-cache for full impl
# ---------------------------------------------------------------------------


# ---------------------------------------------------------------------------
# Q4. BPE encoder (given trained merges)
# ---------------------------------------------------------------------------
def bpe_encode(text: str, merges: list[tuple[str, str]]) -> list[str]:
    """
    Encode `text` using a list of BPE merges (in priority order).

    Each merge (a, b) fuses two adjacent symbols "a" and "b" into the single
    symbol "ab" (plain concatenation, no separator). Encoding starts from
    one-character tokens and repeatedly applies the highest-priority
    (earliest-listed) merge present anywhere in the sequence, merging its
    leftmost occurrence, until no listed pair is adjacent.

    If a pair is listed more than once, its first (highest-priority) position
    wins. Returns the final list of symbols; empty text gives [].
    """
    rank: dict[tuple[str, str], int] = {}
    for i, pair in enumerate(merges):
        # setdefault keeps the first listing of a duplicated pair; tuple() also
        # accepts merges supplied as 2-element lists (e.g. loaded from JSON).
        rank.setdefault(tuple(pair), i)

    tokens = list(text)
    while len(tokens) > 1:
        # Lowest-rank adjacent pair; the strict `<` keeps the leftmost on ties.
        best_i, best_rank = -1, math.inf
        for i, pair in enumerate(zip(tokens, tokens[1:])):
            r = rank.get(pair, math.inf)
            if r < best_rank:
                best_i, best_rank = i, r
        if best_i < 0:
            break
        tokens[best_i:best_i + 2] = [tokens[best_i] + tokens[best_i + 1]]
    return tokens


# ---------------------------------------------------------------------------
# Q5. Top-p (nucleus) sampling
# ---------------------------------------------------------------------------
def top_p_sample(logits: torch.Tensor, p: float = 0.9, temperature: float = 1.0) -> torch.Tensor:
    """
    Sample one token id from `logits` with nucleus (top-p) sampling.

    The nucleus is the smallest prefix of the probability-sorted vocabulary
    whose cumulative mass reaches `p`; the token that crosses the threshold is
    included. The top token is always kept, so p <= 0 degenerates to greedy.

    logits: (V,)
    p: nucleus mass threshold.
    temperature: logits are divided by max(1e-6, temperature) before softmax.
    Returns: a 1-element tensor holding the sampled token id.
    """
    logits = logits / max(1e-6, temperature)
    probs = F.softmax(logits, dim=-1)
    sorted_p, sorted_ix = torch.sort(probs, descending=True)
    # Mass of the strictly higher-ranked tokens. A token belongs to the nucleus
    # iff that mass is still below p (so the threshold-crossing token is kept);
    # testing `cumsum <= p` would drop it and keep less than p of the mass.
    mass_before = torch.cumsum(sorted_p, dim=0) - sorted_p
    keep = mass_before < p
    keep[0] = True   # always keep at least the top token
    sorted_p = sorted_p.masked_fill(~keep, 0.0)
    pick = torch.multinomial(sorted_p / sorted_p.sum(), 1)
    return sorted_ix[pick]


# ---------------------------------------------------------------------------
# Q6. Beam search (length-normalized)
# ---------------------------------------------------------------------------
def beam_search(step_fn, start_id: int, eos_id: int, beam_width: int = 4,
                max_len: int = 64, alpha: float = 0.7) -> list[int]:
    """
    Length-normalized beam search.

    step_fn(token_ids: list[int]) -> log_probs over vocab (1D tensor)
    Hypotheses are ranked by logp / ((5 + L) / 6) ** alpha (Wu et al. 2016),
    where L counts every id in the hypothesis, including start and any eos.
    A hypothesis ends when it emits `eos_id`; the search stops after `max_len`
    steps or once every kept beam has ended.

    Returns the best sequence including start, excluding eos. Only the
    terminating eos is stripped, so a start_id equal to eos_id is preserved.
    Raises ValueError if beam_width < 1.
    """
    if beam_width < 1:
        raise ValueError("beam_width must be >= 1")

    Beam = tuple[float, list[int], bool]   # (logp, ids, done)

    def score(beam: Beam) -> float:
        logp, ids, _ = beam
        return logp / (((5 + len(ids)) / 6) ** alpha)

    beams: list[Beam] = [(0.0, [start_id], False)]
    for _ in range(max_len):
        candidates: list[Beam] = []
        for logp, ids, done in beams:
            if done:
                # Finished hypotheses compete unchanged against the expansions.
                candidates.append((logp, ids, True))
                continue
            log_probs = step_fn(ids)
            # topk fails if k exceeds the vocabulary size.
            k = min(beam_width, log_probs.numel())
            top_lp, top_ix = torch.topk(log_probs, k)
            for lp, ix in zip(top_lp.tolist(), top_ix.tolist()):
                candidates.append((logp + lp, ids + [ix], ix == eos_id))

        beams = sorted(candidates, key=score, reverse=True)[:beam_width]
        if all(done for _, _, done in beams):
            break

    _, best_ids, best_done = max(beams, key=score)
    # eos can only be the last id of a finished hypothesis; strip just that one.
    return best_ids[:-1] if best_done else best_ids


# ---------------------------------------------------------------------------
# Q7. RoPE — apply rotary embeddings to Q and K
# ---------------------------------------------------------------------------
def rope_apply(x: torch.Tensor, base: float = 10000.0, offset: int = 0) -> torch.Tensor:
    """
    Apply rotary positional embeddings (interleaved-pair layout) to `x`.

    x: (..., T, D) where D is even.
    Rotates each pair (x_{2i}, x_{2i+1}) at position t by angle (t * theta_i),
    where theta_i = base^{-2i/D} and t runs over offset, offset + 1, ..., offset + T - 1.

    offset: position of the first row along T (default 0). Pass the number of
        tokens already processed when rotating only the newest rows.
    Returns: tensor with the same shape and dtype as `x`. The rotation is
    computed in float32 (or in x's dtype if wider) and cast back to x.dtype.
    """
    *_, T, D = x.shape
    assert D % 2 == 0
    half = D // 2
    # base^{-i/half} == base^{-2i/D} for i = 0 .. half-1.
    freqs = base ** (-torch.arange(0, half, dtype=torch.float, device=x.device) / half)
    pos = torch.arange(offset, offset + T, dtype=torch.float, device=x.device)
    angles = pos.unsqueeze(1) * freqs.unsqueeze(0)        # (T, half)
    cos, sin = angles.cos(), angles.sin()                 # (T, half) each
    x1, x2 = x[..., 0::2], x[..., 1::2]                   # (..., T, half) each
    rx1 = x1 * cos - x2 * sin
    rx2 = x1 * sin + x2 * cos
    out = torch.stack([rx1, rx2], dim=-1).flatten(-2)     # (..., T, D), pairs re-interleaved
    # Mixing with float32 cos/sin promotes fp16/bf16 inputs to float32; cast
    # back so rotated q/k keep the caller's dtype (e.g. to match an unrotated v).
    return out.to(x.dtype)


# ---------------------------------------------------------------------------
# Q8. Token-aware text chunker
# ---------------------------------------------------------------------------
def chunk_tokens(tokens: list[int], max_size: int = 400, overlap: int = 80) -> list[list[int]]:
    """
    Split `tokens` into overlapping windows of at most `max_size` tokens.

    Consecutive chunks share `overlap` tokens (the window advances by
    max_size - overlap). Chunking stops as soon as a chunk reaches the end of
    the input, so no trailing chunk is wholly contained in its predecessor.
    Empty input yields [].

    Raises ValueError if max_size <= overlap or overlap is negative (a
    negative overlap would silently skip tokens between chunks).
    """
    if max_size <= overlap:
        raise ValueError("max_size must exceed overlap")
    if overlap < 0:
        raise ValueError("overlap must be non-negative")
    step = max_size - overlap
    out: list[list[int]] = []
    for start in range(0, len(tokens), step):
        out.append(tokens[start:start + max_size])
        if start + max_size >= len(tokens):
            break   # this chunk already covers the tail
    return out


# ---------------------------------------------------------------------------
# Q9. Sliding-window perplexity
# ---------------------------------------------------------------------------
@torch.no_grad()
def sliding_perplexity(model, ids: torch.Tensor, window: int, stride: int) -> float:
    """
    Compute PPL on a long sequence by sliding `window`-sized contexts with `stride`.

    Every token is scored once, by the first window that covers it. The first
    window scores all of its tokens after position 0. Each later window scores
    only the tokens beyond the previous window's end, which therefore get
    extra left context. Token 0 is never scored because it has no context.
    With stride >= window, tokens no window can predict from in-window context
    (each window's first token and any gap) are skipped.

    model: callable mapping (B, T) ids to logits (B, T, V), or to an object
           exposing them as `.logits`.
    ids:   (B, N) token ids.
    Returns exp(mean negative log-likelihood over the scored tokens).
    Raises ValueError if stride < 1 or if no token can be scored.
    """
    if stride < 1:
        raise ValueError("stride must be positive")
    n = ids.size(1)
    nll, n_tok, prev_end = 0.0, 0, 0
    for start in range(0, n, stride):
        end = min(start + window, n)
        chunk = ids[:, start:end]
        # First in-chunk target not scored by an earlier window; index 0 of the
        # chunk has no in-window context, so targets begin at index >= 1.
        first = max(1, prev_end - start)
        if first < chunk.size(1):
            out = model(chunk)   # one forward pass per window
            logits = out.logits if hasattr(out, "logits") else out
            # Logits at position j predict token j + 1.
            log_probs = F.log_softmax(logits[:, first - 1:-1], dim=-1)
            target = chunk[:, first:]
            ll = log_probs.gather(-1, target.unsqueeze(-1)).squeeze(-1)
            nll -= ll.sum().item()
            n_tok += target.numel()
        prev_end = end
        if end == n:
            break
    if n_tok == 0:
        raise ValueError("sequence too short to score any token")
    return math.exp(nll / n_tok)


# ---------------------------------------------------------------------------
# Q10. Speculative-decoding accept loop
# ---------------------------------------------------------------------------
def speculative_accept(draft_probs: torch.Tensor, target_probs: torch.Tensor,
                       draft_tokens: torch.Tensor) -> tuple[int, torch.Tensor | None]:
    """
    For K draft tokens with their (draft, target) probabilities, run rejection
    sampling. Returns (n_accepted, optional resampled token at the rejection point).

    Per token i: accept with prob min(1, target[i] / draft[i]). On reject, sample
    from the residual distribution max(target - draft, 0), normalized. If that
    residual is empty (only possible through rounding or unnormalized inputs),
    sample from target[i] instead; if target[i] is also all zero, return None.

    draft_probs, target_probs: (K, V)  — distributions at each draft step
    draft_tokens: (K,)
    Returns (K, None) when every draft token is accepted; otherwise (i, token),
    where i is the first rejected index and token is a (1,) tensor.
    """
    K = draft_tokens.size(0)
    for i in range(K):
        tok = draft_tokens[i]
        t = target_probs[i, tok].item()
        d = draft_probs[i, tok].item()
        ratio = min(1.0, t / max(d, 1e-12))
        # u ~ U[0, 1): strict `<` makes ratio 0 a guaranteed rejection and
        # ratio 1 a guaranteed acceptance.
        if torch.rand(1).item() < ratio:
            continue
        # Reject and resample from the residual.
        residual = (target_probs[i] - draft_probs[i]).clamp(min=0)
        mass = residual.sum()
        if mass <= 0:
            residual = target_probs[i]
            mass = residual.sum()
            if mass <= 0:
                return i, None
        return i, torch.multinomial(residual / mass, 1)
    return K, None


# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
def _test():
    """Run the self-checks for every implemented question; raises AssertionError on failure."""
    torch.manual_seed(0)

    # Q1, Q2 — shapes
    x = torch.randn(2, 8, 64)
    out = attention(x, x, x)
    assert out.shape == x.shape, "Q1 shape"
    # Q1 — with a causal mask, position 0 can only attend to itself.
    causal = torch.tril(torch.ones(8, 8, dtype=torch.bool))
    assert torch.allclose(attention(x, x, x, causal)[:, 0], x[:, 0], atol=1e-6), "Q1 causal mask"
    mha = MultiHeadAttention(64, 4)
    assert mha(x).shape == x.shape, "Q2 shape"

    # Q4 — BPE
    merges = [("a", "b"), ("ab", "c")]
    assert bpe_encode("abcabc", merges) == ["abc", "abc"], f"Q4 got {bpe_encode('abcabc', merges)}"

    # Q5 — top-p returns valid id
    logits = torch.randn(100)
    pick = top_p_sample(logits, p=0.9)
    assert 0 <= pick.item() < 100, "Q5"

    # Q6 — beam search converges to greedy when beam=1
    vocab = 10
    fixed = torch.zeros(vocab)
    fixed[3] = 1.0
    log_probs_fixed = F.log_softmax(fixed, dim=-1)
    seq = beam_search(lambda ids: log_probs_fixed, start_id=0, eos_id=9, beam_width=1, max_len=5)
    assert seq[1:] == [3, 3, 3, 3, 3], f"Q6 got {seq}"

    # Q7 — RoPE preserves shape and norm
    q = torch.randn(2, 4, 8, 16)
    rq = rope_apply(q)
    assert rq.shape == q.shape, "Q7 shape"
    assert torch.allclose(rq.norm(dim=-1), q.norm(dim=-1), atol=1e-5), "Q7 norm preserved"

    # Q8 — chunker
    chunks = chunk_tokens(list(range(1000)), max_size=400, overlap=80)
    assert all(len(c) <= 400 for c in chunks), "Q8 max"
    assert chunks[1][0] == chunks[0][320], "Q8 overlap"

    # Q9 — a uniform model has perplexity equal to the vocabulary size,
    # regardless of how the windows are laid out.
    vocab_u = 11

    def uniform(chunk):
        return torch.zeros(*chunk.shape, vocab_u)

    ppl = sliding_perplexity(uniform, torch.randint(0, vocab_u, (1, 50)), window=16, stride=8)
    assert abs(ppl - vocab_u) < 1e-3, f"Q9 got {ppl}"

    # Q10 — accept always when draft == target
    p = torch.zeros(3, 10)
    p[:, 5] = 1.0
    n, _ = speculative_accept(p.clone(), p.clone(), torch.tensor([5, 5, 5]))
    assert n == 3, "Q10 always accept"

    print("All tests passed ✔")


if __name__ == "__main__":
    _test()
