"""
Lab 04 — Mini Transformer (Decoder-Only) From Scratch
~200 lines. Reused in Phase 5 (nanoGPT) and Phase 9 (KV-cache).
"""
from __future__ import annotations
import math
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F


@dataclass
class GPTConfig:
    """Hyperparameters describing a decoder-only MiniGPT model.

    The order of the fields below is also the positional order accepted by
    the generated constructor.
    """

    # Number of distinct token ids (rows of the token embedding / output logits).
    vocab_size: int = 50_257
    # Number of stacked transformer blocks.
    n_layer: int = 6
    # Attention heads per block; d_model must be divisible by this.
    n_head: int = 8
    # Width of the residual stream.
    d_model: int = 512
    # Hidden width of the feed-forward layer (typically 4 * d_model).
    d_ff: int = 2_048
    # Maximum context length, i.e. number of learned positions.
    block_size: int = 1_024
    # Dropout probability shared by every dropout layer in the model.
    dropout: float = 0.0
    # Share one weight matrix between the token embedding and the output head.
    tie_weights: bool = True


class CausalSelfAttention(nn.Module):
    """Multi-head self-attention in which a position sees only itself and earlier positions."""

    def __init__(self, cfg: GPTConfig):
        super().__init__()
        assert cfg.d_model % cfg.n_head == 0
        self.n_head = cfg.n_head
        self.d_head = cfg.d_model // cfg.n_head
        # One matmul produces queries, keys and values side by side.
        self.qkv = nn.Linear(cfg.d_model, 3 * cfg.d_model, bias=False)
        self.proj = nn.Linear(cfg.d_model, cfg.d_model, bias=False)
        self.attn_drop = nn.Dropout(cfg.dropout)
        self.resid_drop = nn.Dropout(cfg.dropout)
        # Lower-triangular "may attend" table, built once for the longest
        # context and kept out of the state dict.
        allowed = torch.ones(cfg.block_size, cfg.block_size, dtype=torch.bool).tril()
        self.register_buffer("mask", allowed[None, None], persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map x of shape (B, T, C) to an attended tensor of the same shape."""
        batch, seq_len, width = x.shape

        # (B, T, 3C) -> (3, B, n_head, T, d_head), then peel off q, k, v.
        fused = self.qkv(x).view(batch, seq_len, 3, self.n_head, self.d_head)
        q, k, v = fused.permute(2, 0, 3, 1, 4).unbind(0)

        # Scaled dot-product scores, (B, n_head, T, T).
        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.d_head)
        # Positions that lie in the future get -inf so softmax gives them zero weight.
        future = self.mask[:, :, :seq_len, :seq_len].logical_not()
        scores = scores.masked_fill(future, float("-inf"))
        weights = self.attn_drop(scores.softmax(dim=-1))

        mixed = weights @ v                                   # (B, n_head, T, d_head)
        merged = mixed.transpose(1, 2).reshape(batch, seq_len, width)
        return self.resid_drop(self.proj(merged))


class MLP(nn.Module):
    """Position-wise feed-forward net: widen to d_ff, apply GELU, project back to d_model."""

    def __init__(self, cfg: GPTConfig):
        super().__init__()
        self.fc = nn.Linear(in_features=cfg.d_model, out_features=cfg.d_ff, bias=False)
        self.proj = nn.Linear(in_features=cfg.d_ff, out_features=cfg.d_model, bias=False)
        self.drop = nn.Dropout(p=cfg.dropout)

    def forward(self, x):
        """Apply fc -> GELU -> proj -> dropout along the last dimension of x."""
        hidden = F.gelu(self.fc(x))
        projected = self.proj(hidden)
        return self.drop(projected)


class Block(nn.Module):
    """Pre-norm transformer block: LayerNorm precedes each sub-layer, residual is added after."""

    def __init__(self, cfg: GPTConfig):
        super().__init__()
        # Attention sub-layer and its normalisation.
        self.ln1 = nn.LayerNorm(cfg.d_model)
        self.attn = CausalSelfAttention(cfg)
        # Feed-forward sub-layer and its normalisation.
        self.ln2 = nn.LayerNorm(cfg.d_model)
        self.mlp = MLP(cfg)

    def forward(self, x):
        """Return x after the attention and feed-forward residual updates."""
        attended = self.attn(self.ln1(x))
        x = x + attended          # out-of-place: the caller's tensor is left untouched
        transformed = self.mlp(self.ln2(x))
        return x + transformed


class MiniGPT(nn.Module):
    """Decoder-only transformer language model.

    Token embeddings and learned position embeddings are summed, passed
    through ``cfg.n_layer`` blocks and a final LayerNorm, then projected to
    vocabulary logits. When ``cfg.tie_weights`` is set, the output head and
    the token embedding share a single weight matrix.
    """

    def __init__(self, cfg: GPTConfig):
        super().__init__()
        self.cfg = cfg
        self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.d_model)
        self.pos_emb = nn.Embedding(cfg.block_size, cfg.d_model)
        self.drop = nn.Dropout(cfg.dropout)
        self.blocks = nn.ModuleList([Block(cfg) for _ in range(cfg.n_layer)])
        self.ln_f = nn.LayerNorm(cfg.d_model)
        self.head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
        if cfg.tie_weights:
            # Both modules now hold the *same* Parameter object, so it is
            # stored, updated and counted exactly once.
            self.head.weight = self.tok_emb.weight

        # Init like GPT-2: N(0, 0.02) weights, zero biases.
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise Linear/Embedding weights from N(0, 0.02) and zero any Linear bias."""
        if isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Embedding):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)

    def forward(self, idx: torch.Tensor, targets: torch.Tensor | None = None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            idx: Token ids of shape (B, T) with T <= cfg.block_size.
            targets: Optional token ids with the same number of elements as
                ``idx``; when given, the mean cross-entropy against the
                logits is returned as well.

        Returns:
            ``(logits, loss)`` where logits has shape (B, T, vocab_size) and
            loss is a scalar tensor, or None when ``targets`` is None.
        """
        B, T = idx.shape
        assert T <= self.cfg.block_size, (
            f"sequence length {T} exceeds block_size {self.cfg.block_size}"
        )
        pos = torch.arange(T, device=idx.device)
        x = self.tok_emb(idx) + self.pos_emb(pos)        # (B, T, C); positions broadcast over B
        x = self.drop(x)
        for blk in self.blocks:
            x = blk(x)
        x = self.ln_f(x)
        logits = self.head(x)                             # (B, T, V)

        loss = None
        if targets is not None:
            # Flatten batch and time so every position is one classification example.
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                targets.view(-1),
            )
        return logits, loss

    @torch.no_grad()
    def generate(self, idx: torch.Tensor, max_new_tokens: int,
                 temperature: float = 1.0, top_k: int | None = None) -> torch.Tensor:
        """Autoregressively sample ``max_new_tokens`` tokens after ``idx``.

        The module's train/eval mode is not changed here; call ``eval()``
        beforehand if dropout should be disabled while sampling.

        Args:
            idx: Prompt token ids of shape (B, T).
            max_new_tokens: Number of tokens to append to each sequence.
            temperature: Divides the logits before sampling; values below
                1e-6 are clamped to 1e-6.
            top_k: If given, sample only from the ``top_k`` highest-scoring
                tokens. Values larger than the vocabulary are clamped to the
                vocabulary size.

        Returns:
            Tensor of shape (B, T + max_new_tokens) holding prompt and samples.

        Raises:
            ValueError: If ``top_k`` is given and is smaller than 1.
        """
        if top_k is not None and top_k < 1:
            raise ValueError(f"top_k must be >= 1, got {top_k}")
        for _ in range(max_new_tokens):
            # Only the last block_size tokens fit in the position table.
            ctx = idx[:, -self.cfg.block_size:]
            logits, _ = self(ctx)
            logits = logits[:, -1, :] / max(1e-6, temperature)
            if top_k is not None:
                # torch.topk rejects k larger than the dimension, so clamp it.
                k = min(top_k, logits.size(-1))
                kth = torch.topk(logits, k).values[:, [-1]]
                logits = logits.masked_fill(logits < kth, float("-inf"))
            probs = F.softmax(logits, dim=-1)
            next_id = torch.multinomial(probs, 1)
            idx = torch.cat([idx, next_id], dim=1)
        return idx

    def num_params(self) -> int:
        """Return the number of unique trainable/non-trainable parameters.

        ``parameters()`` yields every Parameter object once, so a weight
        shared between ``tok_emb`` and ``head`` is already counted a single
        time; no correction for tying is needed.
        """
        return sum(p.numel() for p in self.parameters())


# ---------------------------------------------------------------------------
# Sanity tests — these are the standard "is my transformer wired correctly?"
# checks that every from-scratch implementation should pass.
# ---------------------------------------------------------------------------
def sanity_init_loss(cfg: GPTConfig):
    """Check that a fresh model's loss on random data is close to log(vocab_size).

    Builds a new MiniGPT from ``cfg``, scores one random batch against random
    targets and prints the measured loss, the expected value and whether they
    agree to within 0.5.

    Args:
        cfg: Model configuration; the sequence length used is capped at
            ``cfg.block_size`` so small-context configs can be checked too.
    """
    model = MiniGPT(cfg)
    model.eval()  # keep dropout (if configured) from perturbing the measurement
    seq_len = min(64, cfg.block_size)
    x = torch.randint(0, cfg.vocab_size, (4, seq_len))
    y = torch.randint(0, cfg.vocab_size, (4, seq_len))
    with torch.no_grad():  # only the loss value is needed, not gradients
        _, loss = model(x, y)
    got = loss.item()
    expected = math.log(cfg.vocab_size)
    print(f"[init-loss]  got={got:.4f}  expected≈{expected:.4f}  ok={abs(got - expected) < 0.5}")


def sanity_overfit_one_batch(cfg: GPTConfig, steps: int = 200):
    """Check that the model can memorise one random batch (loss → 0).

    Trains a new MiniGPT from ``cfg`` with AdamW on a single fixed batch,
    printing the training loss every 50 steps, then prints the loss of the
    trained model on that batch and whether it is below 0.5.

    Args:
        cfg: Model configuration; the sequence length used is capped at
            ``cfg.block_size``.
        steps: Number of optimiser steps. With ``steps <= 0`` no training
            happens and the untrained model's loss is reported.
    """
    model = MiniGPT(cfg)
    seq_len = min(32, cfg.block_size)
    x = torch.randint(0, cfg.vocab_size, (2, seq_len))
    y = torch.randint(0, cfg.vocab_size, (2, seq_len))
    opt = torch.optim.AdamW(model.parameters(), lr=3e-3)
    for s in range(steps):
        _, loss = model(x, y)
        opt.zero_grad()
        loss.backward()
        opt.step()
        if s % 50 == 0:
            print(f"  step {s:4d}  loss={loss.item():.4f}")

    # Score the model *after* the last update with a separate forward pass, so
    # the result is defined even when the loop above never ran.
    model.eval()
    with torch.no_grad():
        _, final = model(x, y)
    final_loss = final.item()
    print(f"[overfit]   final_loss={final_loss:.4f}  ok={final_loss < 0.5}")


if __name__ == "__main__":
    # Deliberately small configuration (relative to the GPTConfig defaults)
    # for the smoke run below.
    cfg = GPTConfig(
        vocab_size=1000,
        n_layer=2,
        n_head=4,
        d_model=128,
        d_ff=512,
        block_size=128,
    )

    # Report the model size first, then run both wiring checks in turn.
    model = MiniGPT(cfg)
    print(f"params = {model.num_params():,}")
    print()

    sanity_init_loss(cfg)
    print()

    sanity_overfit_one_batch(cfg, steps=200)
