🛸 Hitchhiker's Guide — Phase 3: RNNs and Language Modeling

Read this if: You want to internalize why every modern LLM is a "language model", what perplexity means, and where the conceptual bridges are between an RNN and a Transformer. RNNs are not in production for new LLMs in 2026 (transformers and SSMs replaced them) — but their failure modes are exactly what attention was invented to fix, so understanding them sharpens transformer intuition immensely.


0. The 30-second mental model

A language model is a probability distribution over sequences:

$$ P(w_1, w_2, \ldots, w_T) = \prod_{t=1}^T P(w_t \mid w_1, \ldots, w_{t-1}) $$

A neural language model parameterizes that conditional with a network. An RNN maintains a recurrent hidden state h_t = f(h_{t-1}, x_t) that's supposed to summarize all prior tokens; an LSTM does the same with gates that protect against vanishing gradients; a transformer throws the recurrence away and lets every token attend to every other token in parallel. Same task, three architectures.

By the end of Phase 3 you should:

  • Know what an n-gram baseline gives you and why it's the floor for any LM evaluation.
  • Be able to derive Backpropagation Through Time on the whiteboard.
  • Explain vanishing/exploding gradients and how LSTM gates fix them.
  • Compute and interpret perplexity, bits-per-character, and bits-per-byte.
  • Implement a character-level RNN from raw cells (no nn.RNN) and use it to generate Shakespearean text.

1. Language modeling as a discipline

1.1 Why predict the next token?

Because everything is the next token. Translation, summarization, code generation, chat — they're all "given some prefix, what comes next?" If a model assigns high probability to true continuations across a vast and diverse corpus, it has implicitly learned grammar, facts, reasoning patterns, style, and code structure. This is the core hypothesis on which every LLM stands.

1.2 The chain rule and the autoregressive factorization

Any joint distribution over a sequence factorizes as a product of conditionals (chain rule of probability). A model that computes P(w_t | w_<t) for every t is sufficient to:

  • Score any sequence (just multiply).
  • Sample from the model (sample one token at a time, append, repeat).

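That's the autoregressive (AR) style. There are non-AR alternatives (BERT-style masked LMs, diffusion LMs), but AR has won for generation. Note that SSMs such as Mamba are themselves autoregressive; they change the architecture, not the factorization.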
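The two uses of the conditional (scoring and sampling) can be sketched with a toy model. This is a minimal illustration, assuming a hypothetical hand-written bigram table `COND` in place of a real network; a real LM conditions on the whole prefix, not just the previous token.

```python
import math
import random

# Hypothetical toy conditional P(next | previous token); "<s>" marks the start.
# A bigram table keeps the sketch small; a neural LM would replace this lookup.
COND = {
    "<s>": {"the": 0.6, "a": 0.4},
    "the": {"cat": 0.5, "dog": 0.5},
    "a": {"cat": 0.3, "dog": 0.7},
    "cat": {"</s>": 1.0},
    "dog": {"</s>": 1.0},
}

def sequence_log_prob(tokens):
    """Sum log P(w_t | w_{t-1}) over the sequence (the chain rule in log space)."""
    prev, total = "<s>", 0.0
    for tok in tokens:
        total += math.log(COND[prev][tok])
        prev = tok
    return total

def sample_sequence(rng, max_len=10):
    """Sample one token at a time, append, repeat until </s> is drawn."""
    prev, out = "<s>", []
    for _ in range(max_len):
        choices = COND[prev]
        tok = rng.choices(list(choices), weights=list(choices.values()))[0]
        if tok == "</s>":
            break
        out.append(tok)
        prev = tok
    return out

log_p = sequence_log_prob(["the", "cat", "</s>"])
print(math.exp(log_p))                      # 0.6 * 0.5 * 1.0, up to float rounding
print(sample_sequence(random.Random(0)))    # a two-word sample such as ["a", "dog"]
```

Summing log-probabilities instead of multiplying raw probabilities avoids underflow on long sequences, which is why losses are reported in log space.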

1.3 Cross-entropy = next-token loss

Training a language model means minimizing the negative log-likelihood of the true next token at every position:

$$ \mathcal{L} = -\sum_t \log P(w_t \mid w_{<t}; \theta) $$

Equivalently: cross-entropy between the model's distribution and the one-hot distribution at the true token. This is the loss. Pretraining, fine-tuning, distillation all start from this.
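The equivalence between negative log-likelihood and cross-entropy against a one-hot target can be checked numerically. A minimal sketch, assuming a made-up 3-token vocabulary and made-up logits:

```python
import math

def softmax(logits):
    """Numerically stable softmax: subtract the max before exponentiating."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(target_dist, model_dist):
    """H(p, q) = -sum_i p_i log q_i, skipping entries where p_i is zero."""
    return -sum(p * math.log(q) for p, q in zip(target_dist, model_dist) if p > 0)

logits = [2.0, 0.5, -1.0]   # hypothetical model scores over a 3-token vocab
true_index = 0              # hypothetical index of the true next token
probs = softmax(logits)
one_hot = [1.0 if i == true_index else 0.0 for i in range(len(logits))]

nll = -math.log(probs[true_index])   # negative log-likelihood of the true token
ce = cross_entropy(one_hot, probs)   # cross-entropy against the one-hot target
print(nll, ce)                       # the two values coincide
```

With a one-hot target every term but one drops out of the sum, so the cross-entropy collapses to the single `-log` term.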


2. n-gram models — your baseline

Before deep learning, language models were tables of conditional probabilities:

$$ P(w_t \mid w_{t-n+1}, \ldots, w_{t-1}) = \frac{\text{count}(w_{t-n+1}, \ldots, w_t)}{\text{count}(w_{t-n+1}, \ldots, w_{t-1})} $$

For unseen n-grams: smoothing (add-1, Kneser-Ney). Kneser-Ney is the gold-standard pre-deep-learning smoothing. Read Jurafsky & Martin Ch. 3.
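The count ratio and the simplest smoothing scheme can be sketched for bigrams as follows. This illustrates add-k smoothing only (add-1 when `k=1.0`); Kneser-Ney is considerably more involved and is not implemented here. The tiny corpus and the function names are made up for the example.

```python
from collections import Counter

def train_bigram(tokens):
    """Count contexts and bigrams from a flat token list."""
    context_counts = Counter(tokens[:-1])
    bigram_counts = Counter(zip(tokens[:-1], tokens[1:]))
    vocab = sorted(set(tokens))
    return context_counts, bigram_counts, vocab

def bigram_prob(prev, word, context_counts, bigram_counts, vocab, k=1.0):
    """Add-k estimate: (count(prev, word) + k) / (count(prev) + k * |V|)."""
    return (bigram_counts[(prev, word)] + k) / (context_counts[prev] + k * len(vocab))

corpus = "the cat sat on the mat the cat ran".split()
ctx, big, vocab = train_bigram(corpus)

p_mle = bigram_prob("the", "cat", ctx, big, vocab, k=0.0)      # raw count ratio: 2/3
p_smooth = bigram_prob("the", "cat", ctx, big, vocab, k=1.0)   # (2 + 1) / (3 + 6)
p_unseen = bigram_prob("the", "ran", ctx, big, vocab, k=1.0)   # (0 + 1) / (3 + 6)
print(p_mle, p_smooth, p_unseen)
```

Without smoothing, the unseen bigram "the ran" would get probability zero and any test sentence containing it would have infinite perplexity.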

A 5-gram Kneser-Ney model gets roughly 141 perplexity on PTB (and about 68 when trained and tested on the much larger One Billion Word Benchmark). Large transformer LMs land around 20 on PTB. Always include the n-gram baseline before claiming your model is good.

Reference: Jurafsky & Martin, Speech and Language Processing, 3rd ed., Ch. 3 (free draft).


3. The Recurrent Neural Network

3.1 The vanilla RNN cell

$$ h_t = \tanh(W_{xh} x_t + W_{hh} h_{t-1} + b_h) $$

That's it. The hidden state h_t is a fixed-size vector (e.g., 256 dims) that's supposed to summarize all prior tokens. The output prediction is a softmax over W_{hy} h_t + b_y.

Because the same W_{hh} is applied at every step, the network has a fixed parameter count regardless of sequence length. That's beautiful — and dooms it.
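The cell equation can be written out directly. A minimal sketch in plain Python lists, assuming made-up sizes (3-dim input, 2-dim hidden state) and made-up weights; the output projection W_{hy} is left out.

```python
import math

def matvec(W, v):
    """Matrix-vector product for a list-of-rows matrix."""
    return [sum(w * x for w, x in zip(row, v)) for row in W]

def rnn_step(x, h_prev, W_xh, W_hh, b_h):
    """h_t = tanh(W_xh x_t + W_hh h_{t-1} + b_h), element by element."""
    a = matvec(W_xh, x)
    r = matvec(W_hh, h_prev)
    return [math.tanh(ai + ri + bi) for ai, ri, bi in zip(a, r, b_h)]

# Hypothetical weights: 3-dim one-hot inputs, 2-dim hidden state.
W_xh = [[0.1, -0.2, 0.3], [0.0, 0.4, -0.1]]
W_hh = [[0.5, -0.3], [0.2, 0.1]]
b_h = [0.0, 0.0]

h = [0.0, 0.0]
inputs = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]]
for x in inputs:
    h = rnn_step(x, h, W_xh, W_hh, b_h)   # the same weights are reused at every step
print(h)   # still 2 numbers, however long `inputs` is
```

The loop makes both properties visible: the parameter count never depends on the sequence length, and everything the model knows about the prefix has to fit in `h`.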

3.2 Backpropagation Through Time (BPTT)

To train, "unroll" the recurrence into a deep feed-forward network of length T. Apply standard backprop. The gradient of the loss with respect to h_0 involves a product:

$$ \frac{\partial \mathcal{L}}{\partial h_0} \propto \prod_{t=1}^T \frac{\partial h_t}{\partial h_{t-1}} = \prod_{t=1}^T W_{hh}^\top \, \text{diag}(\tanh'(\cdot)) $$

This is a long product of matrices.

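  • If the largest singular value of W_{hh} is < 1 (and tanh' ≤ 1), every factor shrinks the gradient, so the product vanishes exponentially in T. The model can't learn long-range dependencies.
  • If the spectral radius of W_{hh} is > 1, the product can explode. That condition is necessary rather than sufficient, since saturated tanh units (tanh' ≈ 0) can still damp it.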

Both are catastrophic. Vanishing is the more common problem. Exploding is mitigated cheaply by gradient clipping (torch.nn.utils.clip_grad_norm_).
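Both regimes, and the clipping remedy, can be seen in a few lines. This sketch makes simplifying assumptions: the recurrent matrix is a hypothetical 2×2 scaled rotation (so every singular value equals `scale`), and the tanh' factor is ignored, which could only shrink the product further. `clip_by_global_norm` is a hand-written illustration of the rescaling idea behind norm clipping, not PyTorch's implementation.

```python
import math

def matmul(A, B):
    """Plain-Python matrix product for small square matrices."""
    n = len(B)
    return [[sum(A[i][k] * B[k][j] for k in range(n)) for j in range(len(B[0]))]
            for i in range(len(A))]

def frobenius_norm(A):
    return math.sqrt(sum(v * v for row in A for v in row))

def scaled_rotation(scale, angle=0.3):
    """2x2 rotation times `scale`: all singular values equal `scale`."""
    c, s = math.cos(angle), math.sin(angle)
    return [[scale * c, -scale * s], [scale * s, scale * c]]

def product_norm(W, steps):
    """Norm of W multiplied by itself `steps` times (the BPTT product, minus tanh')."""
    P = [[1.0, 0.0], [0.0, 1.0]]
    for _ in range(steps):
        P = matmul(W, P)
    return frobenius_norm(P)

def clip_by_global_norm(grads, max_norm):
    """Rescale the gradient vector so its L2 norm is at most max_norm."""
    total = math.sqrt(sum(g * g for g in grads))
    if total <= max_norm:
        return list(grads)
    return [g * max_norm / total for g in grads]

shrinking = product_norm(scaled_rotation(0.9), 100)   # about sqrt(2) * 0.9**100
growing = product_norm(scaled_rotation(1.1), 100)     # about sqrt(2) * 1.1**100
clipped = clip_by_global_norm([30.0, 40.0], max_norm=5.0)   # norm 50 rescaled to norm 5
print(shrinking, growing, clipped)
```

Clipping only bounds the size of an exploding gradient; it does nothing for a vanishing one, which is why vanishing needs an architectural fix.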

For long sequences: truncated BPTT — backprop only through the last K steps; detach the hidden state across boundaries.

3.3 LSTM — gating to the rescue

Hochreiter & Schmidhuber (1997). Add a cell state c_t that flows through with mostly identity-like updates, controlled by three gates (forget f, input i, output o):

$$ \begin{aligned} f_t &= \sigma(W_f [x_t, h_{t-1}] + b_f) \\ i_t &= \sigma(W_i [x_t, h_{t-1}] + b_i) \\ o_t &= \sigma(W_o [x_t, h_{t-1}] + b_o) \\ g_t &= \tanh(W_g [x_t, h_{t-1}] + b_g) \\ c_t &= f_t \odot c_{t-1} + i_t \odot g_t \\ h_t &= o_t \odot \tanh(c_t) \end{aligned} $$

Why it works: the cell state c_t is updated additively (c_{t-1} + ...), so the gradient through c is roughly the identity matrix times the forget gate. If the forget gate is near 1, gradients flow through hundreds of steps without vanishing.
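The six equations and the gradient argument can be sketched together. This is a minimal illustration with a 1-dim input and 1-dim hidden state; the `params` values are made up, with the forget-gate bias set to 5.0 so that f_t stays near 1. `direct_path_grad` tracks only the additive path ∂c_t/∂c_{t-1} = f_t and ignores the indirect dependence of the gates on h_{t-1}.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def affine(W, b, v):
    """W v + b for a list-of-rows matrix W."""
    return [sum(w * x for w, x in zip(row, v)) + bi for row, bi in zip(W, b)]

def lstm_step(x, h_prev, c_prev, params):
    """One LSTM step; params maps gate name -> (W, b). Also returns f for inspection."""
    xh = x + h_prev   # list concatenation stands in for [x_t, h_{t-1}]
    f = [sigmoid(z) for z in affine(*params["f"], xh)]
    i = [sigmoid(z) for z in affine(*params["i"], xh)]
    o = [sigmoid(z) for z in affine(*params["o"], xh)]
    g = [math.tanh(z) for z in affine(*params["g"], xh)]
    c = [ft * cp + it * gt for ft, cp, it, gt in zip(f, c_prev, i, g)]
    h = [ot * math.tanh(ct) for ot, ct in zip(o, c)]
    return h, c, f

# Hypothetical parameters; the large forget bias keeps f_t = sigmoid(5), about 0.993.
params = {
    "f": ([[0.0, 0.0]], [5.0]),
    "i": ([[0.5, 0.5]], [0.0]),
    "o": ([[0.5, 0.5]], [0.0]),
    "g": ([[1.0, 0.5]], [0.0]),
}

h, c = [0.0], [0.0]
direct_path_grad = 1.0
for _ in range(100):
    h, c, f = lstm_step([1.0], h, c, params)
    direct_path_grad *= f[0]   # d c_t / d c_{t-1} along the additive path is f_t
print(direct_path_grad)        # roughly 0.5 after 100 steps
```

Compare that with the 0.9**100 ≈ 3e-5 decay of a contracting vanilla recurrence over the same horizon: the gradient survives because the forget gate, not a fixed weight matrix, sets the per-step factor.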

3.4 GRU — fewer gates

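Cho et al. (2014). Merges the forget and input gates into a single update gate and folds the cell state into the hidden state, leaving three weight blocks instead of the LSTM's four (about 25% fewer parameters at the same hidden size). Usually comparable to LSTM in practice.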
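For reference, one common way to write the GRU in the same notation as the LSTM above is sketched below. The symbols z_t (update gate), r_t (reset gate), and \tilde{h}_t (candidate state) are introduced here for illustration, and papers differ on whether z_t or 1 - z_t multiplies the old state.

$$ \begin{aligned} z_t &= \sigma(W_z [x_t, h_{t-1}] + b_z) \\ r_t &= \sigma(W_r [x_t, h_{t-1}] + b_r) \\ \tilde{h}_t &= \tanh(W_h [x_t, r_t \odot h_{t-1}] + b_h) \\ h_t &= (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t \end{aligned} $$

The last line plays the role of the LSTM's c_t update: 1 - z_t acts as the forget gate and z_t as the input gate, so one gate controls both.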

3.5 Stacking and bidirectionality

  • Stacked: feed h_t^{(1)} of layer 1 as input to layer 2. Each layer learns higher-level features. Beyond ~3 layers, returns diminish.
  • Bidirectional: a forward RNN + a backward RNN; concatenate. Useful for tagging/classification but not for autoregressive generation (you can't see the future at inference).

3.6 Why RNNs lost to Transformers

| Issue | RNN | Transformer |
|---|---|---|
| Parallelism | None: must process tokens sequentially | Full: all positions in parallel during training |
| Long-range dependencies | Hard (vanishing) | Easy (direct attention) |
| Ease of scaling | Poor | Excellent |
| Inference cost per generated token | O(1), but strictly sequential | O(T) with a KV cache; O(T²) without |
| Memory at long context | O(1) hidden state | O(T) KV cache |

The last row is interesting: RNNs have constant memory at inference, which is why State Space Models and related architectures (Mamba, S5, Hyena) are mounting a comeback for very long contexts. RNN literacy still matters.


4. Perplexity and friends

4.1 Perplexity

$$ \text{PPL} = \exp\left(\frac{1}{N} \sum_{i=1}^N -\log P(w_i \mid w_{<i})\right) = \exp(\bar{\mathcal{L}}) $$

Intuition: "if the model treated every step as a uniform choice over PPL options, it would have the same loss." Lower is better. PPL = vocab_size means random; PPL = 1 means perfect.
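The definition and its two extremes can be checked in a few lines. A minimal sketch, assuming natural-log probabilities and a made-up vocabulary size of 50:

```python
import math

def perplexity(token_log_probs):
    """exp of the mean negative log-probability (natural log) per token."""
    mean_nll = -sum(token_log_probs) / len(token_log_probs)
    return math.exp(mean_nll)

vocab_size = 50                                  # hypothetical vocabulary size
uniform = [math.log(1.0 / vocab_size)] * 8       # a model that guesses uniformly
perfect = [math.log(1.0)] * 8                    # a model that is always certain and right

print(perplexity(uniform))   # equals vocab_size, up to float rounding
print(perplexity(perfect))   # 1.0
print(math.exp(2.3))         # a mean loss of 2.3 nats/token is a PPL of about 9.97
```

The last line is the conversion worth memorizing: a reported cross-entropy in nats turns into perplexity with a single `exp`.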

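PPL is not comparable across tokenizers: a model with a 50k subword vocab cannot be PPL-compared to a model with a 30k vocab, because the two split the same text into different numbers of tokens and so average the loss over different units. To compare across tokenizers use: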

4.2 Bits-per-character (BPC) / Bits-per-byte (BPB)

$$ \text{BPB} = \frac{\text{total loss in nats (summed over all tokens)} \cdot \log_2 e}{\text{number of bytes in the original text}} $$

Because bytes are tokenizer-agnostic, BPB lets you fairly compare any LM. State-of-the-art LMs on enwik8 reach ~0.94 BPB.
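A small numeric sketch shows why the byte normalization matters. The two tokenizers and their loss values below are hypothetical, chosen so that both models assign the same total probability to the same text while splitting it into different numbers of tokens.

```python
import math

def bits_per_byte(total_nll_nats, text):
    """Total loss in nats, converted to bits, divided by the UTF-8 byte length."""
    n_bytes = len(text.encode("utf-8"))
    return total_nll_nats * math.log2(math.e) / n_bytes

text = "hello world"   # 11 bytes in UTF-8

# Hypothetical: tokenizer A splits the text into 2 tokens, tokenizer B into 4.
mean_nll_a, n_tokens_a = 4.0, 2
mean_nll_b, n_tokens_b = 2.0, 4

ppl_a, ppl_b = math.exp(mean_nll_a), math.exp(mean_nll_b)
bpb_a = bits_per_byte(mean_nll_a * n_tokens_a, text)
bpb_b = bits_per_byte(mean_nll_b * n_tokens_b, text)
print(ppl_a, ppl_b)   # very different per-token perplexities
print(bpb_a, bpb_b)   # identical bits-per-byte
```

Model A looks several times worse on PPL only because each of its tokens covers more text; BPB removes that artifact.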

4.3 What "good" perplexity looks like

  • 5-gram Kneser-Ney on PTB: ~141 PPL.
  • Char-RNN on Tiny Shakespeare (small): ~5–10 PPL (chars are easier per-step).
  • GPT-2 small on WikiText-103: ~37 PPL zero-shot (its ~29 figure is for WikiText-2).
  • GPT-3 175B on PTB: ~20 PPL.
  • Frontier LLMs on web text: ~6–10 PPL on held-out web.

5. Sampling from a language model

You'll meet these again in Phase 9. Preview:

  • Greedy (argmax): deterministic; can repeat.
  • Beam search: keep top-k partial sequences. Better for translation; rare in chat (boring outputs).
  • Temperature: divide logits by T. T < 1 sharpens, T > 1 flattens.
  • Top-k: sample only from the k most-likely tokens.
  • Top-p (nucleus): sample from the smallest set whose cumulative prob ≥ p. Adapts to entropy.
  • Repetition penalty / no-repeat n-gram: hacks to prevent loops.
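The first five strategies in the list can be sketched over a single logit vector. This is a minimal illustration on made-up logits for a 4-token vocabulary; the helper names are chosen for this example, and beam search and repetition penalties are left out.

```python
import math
import random

def softmax(logits, temperature=1.0):
    """Divide logits by the temperature, then apply a stable softmax."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def renormalize(probs, keep):
    """Zero out everything outside `keep` and rescale the rest to sum to 1."""
    kept = set(keep)
    total = sum(probs[i] for i in kept)
    return [probs[i] / total if i in kept else 0.0 for i in range(len(probs))]

def top_k_filter(probs, k):
    """Keep only the k most likely tokens."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return renormalize(probs, order[:k])

def top_p_filter(probs, p):
    """Keep the smallest most-likely-first set whose cumulative probability >= p."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    keep, cumulative = [], 0.0
    for i in order:
        keep.append(i)
        cumulative += probs[i]
        if cumulative >= p:
            break
    return renormalize(probs, keep)

def sample(probs, rng):
    return rng.choices(range(len(probs)), weights=probs)[0]

logits = [2.0, 1.0, 0.5, -1.0]   # hypothetical next-token logits
greedy = max(range(len(logits)), key=lambda i: logits[i])   # argmax
sharp = softmax(logits, temperature=0.5)   # more mass on the top token
flat = softmax(logits, temperature=2.0)    # closer to uniform
nucleus = top_p_filter(softmax(logits), p=0.8)
token = sample(top_k_filter(softmax(logits), k=2), random.Random(0))
print(greedy, sharp[0], flat[0], nucleus, token)
```

On these logits the nucleus at p = 0.8 keeps two tokens; on a flatter distribution the same p would keep more, which is the sense in which top-p adapts to entropy while top-k does not.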

6. Lab 01 walkthrough (lab-01-char-rnn)

6.1 What you'll build

  • A VanillaRNNCell — implemented as the raw tanh(Wxh x + Whh h) math, not nn.RNN. The point is to see autograd handle BPTT.
  • A CharRNN module — embedding → stacked RNN cells → linear projection to vocab.
  • A train() loop that processes Tiny Shakespeare in fixed-length sequences, with TBPTT (detach() the hidden state between batches).
  • A sample() method that generates new text given a seed string.

6.2 Things to internalize while reading the solution

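  • Why detach() between batches? Without it, the autograd graph keeps extending back through every earlier batch. In PyTorch the next backward() then typically fails because the earlier graph's buffers were already freed, and with retain_graph=True memory grows until it OOMs. Detaching treats the prior hidden state as a constant input.
  • Why is the loss reshaped to (B*T, V) for cross-entropy? Because F.cross_entropy wants the class dimension at dim 1: logits of shape (N, V) with targets of shape (N,). Flattening (B, T, V) to (B*T, V) and (B, T) to (B*T,) is the simplest way to satisfy that (transposing the logits to (B, V, T) also works). The (B, T) structure is irrelevant to the per-position loss.
  • Why no causal mask? Because RNNs are causal by construction: h_t depends only on x_{≤t}, through h_{t-1}.
  • Why stack the cells but not parallelize them? Layer l at step t needs layer l-1's output at the same step and its own hidden state from step t-1. So the time dimension is processed sequentially in an outer loop, with the layers visited in order in an inner loop at each step; only the batch dimension runs in parallel.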

6.3 Watch the loss curve

Early training: loss drops fast as the model learns the unigram distribution. After a few hundred steps it learns bigram statistics, then short word fragments, then real words, then word ordering. By 5k steps, it should produce something that looks like Shakespeare-flavored gibberish. By 20k+, full pseudo-grammatical lines. (Famous Karpathy 2015 blog post.)


7. References

  • Karpathy, The Unreasonable Effectiveness of Recurrent Neural Networks (2015) — required reading.
  • Olah, Understanding LSTM Networks (2015) — required reading; the diagrams.
  • Hochreiter & Schmidhuber (1997), Long Short-Term Memory.
  • Cho et al. (2014), Learning Phrase Representations using RNN Encoder–Decoder for Statistical Machine Translation — GRU.
  • Sutskever, Vinyals, Le (2014), Sequence to Sequence Learning with Neural Networks — the seq2seq paper.
  • Bahdanau, Cho, Bengio (2015), Neural Machine Translation by Jointly Learning to Align and Translate — the attention paper that started everything Phase 4 covers.
  • Jurafsky & Martin, Speech and Language Processing, 3rd ed., Ch. 9 (RNNs and LSTMs).
  • Deep Learning (Goodfellow, Bengio, Courville), Ch. 10.
  • Pascanu, Mikolov, Bengio (2013), On the difficulty of training recurrent neural networks — the vanishing/exploding gradient analysis.

8. Common interview questions on Phase 3 material

  1. Walk me through BPTT on a 3-step RNN.
  2. What causes vanishing gradients in vanilla RNNs and how do LSTMs help?
  3. Compute perplexity from a cross-entropy loss of 2.3 nats per token.
  4. Why is BPB more honest than PPL across tokenizers?
  5. What's a Kneser-Ney 5-gram baseline and when is it competitive?
  6. Why didn't RNNs scale to GPT-3 sizes?
  7. What's truncated BPTT and why do we need it?
  8. Compare LSTM vs GRU.
  9. Why are state-space models (Mamba) suddenly interesting again?
  10. Implement an LSTM cell on a whiteboard.

9. From solid → exceptional

  • Reimplement an LSTM cell from scratch (no nn.LSTMCell) and train on Tiny Shakespeare. Compare loss curves and sample quality vs vanilla RNN.
  • Reproduce Karpathy's char-RNN results on Linux source code; show the model learns to balance braces and indent.
  • Implement a GRU alongside; benchmark perplexity at equal parameter count.
  • Train a 1-layer LSTM on enwik8; compute BPB; compare to the published IndyLSTM / mLSTM numbers (roughly 1.1–1.3 BPB, depending on model size and dynamic evaluation).
  • Read the original attention paper (Bahdanau 2015) and implement attention as an add-on to a seq2seq RNN encoder-decoder. This gives you the conceptual bridge to Phase 4.
  • Skim the Mamba paper (Gu & Dao, 2023) and write a one-page comparison: how is Mamba different from an LSTM?

| Day | Activity |
|---|---|
| Mon | Karpathy RNN blog + Olah LSTM blog |
| Tue | Read Jurafsky & Martin Ch. 3 (n-grams) and Ch. 9 (RNN/LSTM) |
| Wed | Lab 01: implement char-RNN, get it to train |
| Thu | Sample at multiple temperatures; tune until output is interesting |
| Fri | Implement LSTM cell extension; compare |
| Sat | Read Bahdanau 2015 (attention preview) |
| Sun | Mock interview yourself on the 10 questions; write BPTT derivation in a notebook |