Lab 01 — Word2Vec Skip-Gram with Negative Sampling (Solution Walkthrough)

Phase: 2 — Classical NLP & Embeddings | Difficulty: ⭐⭐⭐☆☆ | Time: 3–5 hours

Concept primer: ../HITCHHIKERS-GUIDE.md §Word2Vec. This document walks through the code in solution.py and explains every non-obvious choice.

Run

pip install -r requirements.txt
wget http://mattmahoney.net/dc/text8.zip && unzip text8.zip -d data/
python solution.py --data ./data/text8 --epochs 3

0. The mission

Train a 100-dim word embedding on text8 (a 100 MB cleaned slice of English Wikipedia) using Skip-Gram with Negative Sampling (SGNS). At the end:

nearest("king")  → ["queen", "prince", "throne", "kings", "monarch", ...]
nearest("paris") → ["france", "london", "berlin", "vienna", "rome", ...]

…with no labels, just from co-occurrence. This is the experiment from Mikolov et al. (2013) that popularized dense word embeddings and set the stage for modern NLP.


1. The math

For each (center $c$, context $o$) pair:

$$ \mathcal{L} = -\log \sigma(v_c \cdot u_o) - \sum_{k=1}^{K} \log \sigma(-v_c \cdot u_{n_k}), \quad n_k \sim P_n $$

where $P_n(w) \propto \text{freq}(w)^{0.75}$ is the negative-sampling distribution, and $K$ (5–20) is the number of negatives per positive.

Two embedding tables: an input matrix $V$ whose rows $v_w$ represent words as centers, and an output matrix $U$ whose rows $u_w$ represent words as contexts or negatives. By convention we keep $V$ as the final embeddings.
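To make the loss concrete, here is a minimal NumPy sketch of the per-pair objective above for a single center. The function names are hypothetical (not taken from solution.py), and log σ is computed with the stable identity log σ(x) = -log(1 + e^{-x}). With the output table at zero every dot product is 0, so the loss is exactly $(1+K)\log 2$.

```python
import numpy as np


def log_sigmoid(x):
    """Numerically stable log sigma(x) = -log(1 + exp(-x))."""
    return -np.logaddexp(0.0, -x)


def sgns_pair_loss(v_c, u_o, u_neg):
    """Per-pair SGNS loss (hypothetical helper).

    v_c:   (D,)   center vector, a row of V
    u_o:   (D,)   true context vector, a row of U
    u_neg: (K, D) negative context vectors, rows of U drawn from P_n
    """
    pos = log_sigmoid(v_c @ u_o)              # log sigma(v_c . u_o)
    neg = log_sigmoid(-(u_neg @ v_c)).sum()   # sum_k log sigma(-v_c . u_{n_k})
    return -(pos + neg)


rng = np.random.default_rng(0)
D, K = 100, 5
v_c = rng.uniform(-0.5 / D, 0.5 / D, size=D)   # input-table style init
u_o = np.zeros(D)                              # output table starts at zero
u_neg = np.zeros((K, D))
print(sgns_pair_loss(v_c, u_o, u_neg))         # (1 + K) * log 2, about 4.159
```

Pushing $v_c \cdot u_o$ up and each $v_c \cdot u_{n_k}$ down drives the loss toward 0; reversing them makes it grow roughly linearly in the scores.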


2. build_vocab — three things in one

from collections import Counter
import numpy as np

def build_vocab(words, min_count=5):
    counts = Counter(words)
    vocab = [w for w, c in counts.items() if c >= min_count]   # drop rare words
    w2i = {w: i for i, w in enumerate(vocab)}
    freqs = np.array([counts[w] for w in vocab], dtype=np.float64)
    neg_dist = freqs ** 0.75          # unigram^0.75 noise distribution P_n
    neg_dist /= neg_dist.sum()
    return w2i, vocab, neg_dist

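The exponent 0.75 is Mikolov's empirical choice. Raising counts to a power below 1 flattens the distribution: very common words lose probability mass (you don't want every negative to be "the"), rare words gain some, and the probability ratio between any two words is raised to the 0.75 power. Keeping the exponent above 0 preserves the frequency ordering; at 0 the distribution would be uniform, so rare words would be drawn as often as common ones and make uninformative negatives.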
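A small worked example of the flattening, using three hypothetical word counts rather than real text8 statistics. Because probability ratios get raised to the 0.75 power, a 10,000× count gap between the frequent and rare word becomes a 1,000× gap in $P_n$.

```python
import numpy as np

# Hypothetical toy counts: one very frequent word, one mid-frequency, one rare.
counts = np.array([1_000_000, 10_000, 100], dtype=np.float64)

unigram = counts / counts.sum()      # plain unigram distribution
smoothed = counts ** 0.75            # same transform as build_vocab
smoothed /= smoothed.sum()

for name, p, q in zip(["frequent", "mid", "rare"], unigram, smoothed):
    print(f"{name:9s} unigram={p:.5f}  unigram^0.75={q:.5f}")

# Ratio frequent/rare: 10_000 under the unigram, 10_000 ** 0.75 = 1_000 after smoothing.
print(unigram[0] / unigram[2], smoothed[0] / smoothed[2])
```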

min_count=5: drop any word seen <5 times. For text8 (~17M tokens) this prunes ~250k unique words to ~70k. Removes most typos and proper-noun fluff.


3. SkipGramDataset — subsampling and pair generation

3.1 Frequent-word subsampling

self.keep = np.minimum(1.0, np.sqrt(subsample_t / f) + subsample_t / f)

For each center occurrence, probabilistically drop with probability 1 - keep[w]. Why?

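  • "the" makes up roughly 5% of tokens. Without subsampling, about 1 in 20 training pairs would have "the" as the center, and the handful of most frequent function words together would supply a large share of all pairs. Those pairs carry little signal, because "the" co-occurs with almost everything.
  • Let f be a word's relative frequency. For very common words f ≫ t (with t = 1e-4), so keep ≈ sqrt(t/f) ≪ 1.
  • For rare words keep saturates at 1 (this happens whenever f ≤ about 2.6t), so they are never dropped.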

Mikolov et al. (2013) report that subsampling speeds up training by roughly 2–10× and improves the vectors of less frequent words.
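To see the numbers, here is a short sketch that evaluates the same keep formula as self.keep for a few relative frequencies, assuming t = 1e-4 as above. The helper name keep_prob is hypothetical.

```python
import numpy as np


def keep_prob(f, t=1e-4):
    """Keep probability for a center with relative frequency f
    (same formula as self.keep: min(1, sqrt(t/f) + t/f))."""
    return np.minimum(1.0, np.sqrt(t / f) + t / f)


for f in [0.05, 1e-3, 1e-4, 1e-5]:
    print(f"f={f:<7g} keep={keep_prob(f):.3f}")
```

A word at 5% frequency keeps under 5% of its center occurrences, a word at 0.1% keeps about 42%, and anything at or below f = t is always kept.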

3.2 Dynamic window with random shrinking

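for i, center in enumerate(self.ids):
    if rng.random() > self.keep[center]:   # drop with probability 1 - keep
        continue
    # Assumes rng is a random.Random, whose randint is inclusive: w in {1, ..., window}.
    # (NumPy's RandomState.randint excludes the upper bound; there, use randint(1, self.window + 1).)
    w = rng.randint(1, self.window)        # random window per center
    for j in range(max(0, i - w), min(len(self.ids), i + w + 1)):
        if j == i:
            continue
        yield center, self.ids[j]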

The window size is resampled per center word, uniformly from 1 to window. A context word at distance d is used only when the sampled w ≥ d, which happens with probability (window - d + 1)/window. In expectation this weights contexts with a linearly decaying (triangular) kernel: nearer words count more, at no extra cost.
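A quick simulation can check the triangular-kernel claim. This sketch (the function name is hypothetical) draws w uniformly from {1, ..., window} with Python's random.Random, counts how often each distance d falls inside the window, and compares the rates with (window - d + 1)/window.

```python
import random
from collections import Counter


def inclusion_rates(window=5, trials=200_000, seed=0):
    """Fraction of centers for which a context at distance d lies inside the
    randomly shrunk window w ~ Uniform{1, ..., window}."""
    rng = random.Random(seed)
    hits = Counter()
    for _ in range(trials):
        w = rng.randint(1, window)      # inclusive on both ends
        for d in range(1, w + 1):
            hits[d] += 1
    return {d: hits[d] / trials for d in range(1, window + 1)}


W = 5
empirical = inclusion_rates(W)
for d in range(1, W + 1):
    expected = (W - d + 1) / W          # linearly decaying weight
    print(f"d={d}  empirical={empirical[d]:.3f}  expected={expected:.3f}")
```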

IterableDataset (vs Dataset) means we stream pairs instead of materializing all ~100M of them.
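A minimal sketch of the streaming idea, assuming PyTorch's torch.utils.data.IterableDataset and the default collate. ToyPairStream is a hypothetical class, not the one in solution.py, and it omits subsampling for brevity: pairs are produced lazily in __iter__, so memory stays proportional to the token list rather than to the number of pairs.

```python
import random

from torch.utils.data import DataLoader, IterableDataset


class ToyPairStream(IterableDataset):
    """Hypothetical minimal stream of (center, context) id pairs."""

    def __init__(self, ids, window=2, seed=0):
        self.ids = ids
        self.window = window
        self.seed = seed

    def __iter__(self):
        # Fixed seed: every pass sees the same window sizes (fine for a demo).
        rng = random.Random(self.seed)
        for i, center in enumerate(self.ids):
            w = rng.randint(1, self.window)
            for j in range(max(0, i - w), min(len(self.ids), i + w + 1)):
                if j != i:
                    yield center, self.ids[j]


stream = ToyPairStream(list(range(10)), window=2)
loader = DataLoader(stream, batch_size=4, num_workers=0)
first = next(iter(loader))
print(first)   # default collate: [tensor of 4 centers, tensor of 4 contexts]
```

num_workers=0 matches the advice in the pitfalls below; with more workers each one would iterate its own full copy of the stream.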


4. The model — SkipGramNS

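import torch
import torch.nn as nn
import torch.nn.functional as F

class SkipGramNS(nn.Module):
    def __init__(self, vocab_size, dim=100):
        super().__init__()
        self.in_emb = nn.Embedding(vocab_size, dim)    # V: center vectors v_c
        self.out_emb = nn.Embedding(vocab_size, dim)   # U: context/negative vectors u_o
        nn.init.uniform_(self.in_emb.weight, -0.5/dim, 0.5/dim)
        nn.init.zeros_(self.out_emb.weight)

  • Two tables, not one: a word plays different roles as a center and as a context. With a single table the score of a word against itself would be $\|v_w\|^2 \ge 0$, so the model would always give a word at least probability $\sigma(0) = 0.5$ of appearing in its own context, even though that is rare.
  • Init scale 0.5/dim keeps the dot products $v_c \cdot u_o$ small early in training, so the sigmoids start far from saturation.
  • Output init 0: at step 0 every $v_c \cdot u_o = 0$, so $\sigma(0) = 0.5$ and each of the $1 + K$ terms of the §1 loss equals $\log 2$. The initial per-pair loss is therefore $(1+K)\log 2 \approx 4.16$ for $K = 5$, a clean baseline to check the first logged loss against.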
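
    # Method of SkipGramNS (shown separately here).
    def forward(self, center, pos, neg):
        v_c = self.in_emb(center)              # (B, D)
        v_p = self.out_emb(pos)                # (B, D)
        v_n = self.out_emb(neg)                # (B, K, D)
        pos_score = (v_c * v_p).sum(-1)                              # (B,)
        neg_score = torch.bmm(v_n, v_c.unsqueeze(-1)).squeeze(-1)    # (B, K)
        # Sum over the K negatives as in §1, then average over the batch.
        # (A plain .mean() over (B, K) would down-weight negatives by 1/K.)
        loss = -(F.logsigmoid(pos_score) + F.logsigmoid(-neg_score).sum(-1)).mean()
        return loss

  • (v_c * v_p).sum(-1) is an elementwise multiply followed by a sum: the per-row dot product, with no reshaping needed (unlike bmm).
  • bmm(v_n, v_c.unsqueeze(-1)) is a batched matrix-vector product, (B, K, D) @ (B, D, 1) → (B, K, 1): for each row, the K dot products of the negatives against v_c.
  • F.logsigmoid, not log(sigmoid(x)), for numerical stability. The naive composition underflows sigmoid(x) to 0 for very negative x, giving log(0) = -inf and then NaN gradients.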
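The underflow is easy to reproduce. This NumPy sketch uses -np.logaddexp(0, -x) as a stand-in for F.logsigmoid: it is the same identity log σ(x) = -log(1 + e^{-x}), not PyTorch's kernel. Float32 matches the default dtype of PyTorch embeddings.

```python
import numpy as np

x = np.array([-200.0, -20.0, 0.0, 20.0], dtype=np.float32)

with np.errstate(over="ignore", divide="ignore"):
    naive = np.log(1.0 / (1.0 + np.exp(-x)))   # log(sigmoid(x)) composed naively

stable = -np.logaddexp(np.float32(0.0), -x)    # log sigma(x) = -log(1 + e^{-x})

print(naive)    # first entry is -inf: exp(200) overflows float32, sigmoid becomes 0
print(stable)   # first entry is -200.0
```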

5. collate — batching pairs and sampling negatives

negatives = torch.multinomial(neg_dist_t, len(batch) * n_neg, replacement=True).view(-1, n_neg)
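  • replacement=True is essential. Without replacement, each word can be drawn at most once per call, so frequent words are capped at a single draw and the rest of the 2560 draws spill into the rare tail; the sampled negatives would no longer follow $P_n$.
  • We don't filter cases where a negative equals the positive. For a positive o this happens with probability $P_n(o)$ per draw, which is small for all but the most frequent words, and the occasional collision is outweighed by the other negatives.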
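A sketch of a collate function built around that torch.multinomial call. collate_pairs is a hypothetical name and signature (solution.py's version may differ); it stacks centers and positives and draws negatives i.i.d. from $P_n$ with replacement.

```python
import torch


def collate_pairs(batch, neg_dist_t, n_neg=5):
    """Hypothetical collate_fn: batch is a list of (center_id, context_id) pairs.

    Returns center (B,), pos (B,), neg (B, n_neg) LongTensors.
    """
    center = torch.tensor([c for c, _ in batch], dtype=torch.long)
    pos = torch.tensor([o for _, o in batch], dtype=torch.long)
    neg = torch.multinomial(neg_dist_t, len(batch) * n_neg, replacement=True)
    return center, pos, neg.view(-1, n_neg)


# Toy 4-word vocabulary; word 0 is far more frequent than the others.
counts = torch.tensor([1000.0, 10.0, 10.0, 10.0])
neg_dist_t = counts ** 0.75
neg_dist_t /= neg_dist_t.sum()

center, pos, neg = collate_pairs([(1, 2), (3, 0), (2, 1)], neg_dist_t)
print(center.shape, pos.shape, neg.shape)   # torch.Size([3]) torch.Size([3]) torch.Size([3, 5])
```

To plug it into a DataLoader, the extra arguments can be bound with functools.partial(collate_pairs, neg_dist_t=neg_dist_t, n_neg=5).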

6. The training loop

opt = torch.optim.Adam(model.parameters(), lr=2.5e-3)

LR is high (~10× a typical transformer LR) because (a) embeddings are linear → no exploding-gradient risk, (b) each parameter is touched rarely (sparse access pattern), so per-update steps must be larger.

batch_size=512, n_neg=5 → each step scores 512 positive pairs and 512 × 5 = 2560 negatives, i.e. 3072 dot products of dimension 100.
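The pieces fit together in a loop like the following self-contained sketch. It assumes a compact stand-in for SkipGramNS (same init, sum-over-negatives loss as in §1) and a random, hypothetical toy batch with B = 512 and K = 5 that is reused every step; solution.py instead iterates a DataLoader over real pairs. The first logged loss should be $6 \log 2 \approx 4.1589$.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
V, D, B, K = 50, 16, 512, 5

# Compact stand-in for SkipGramNS: same init as above.
in_emb = nn.Embedding(V, D)
out_emb = nn.Embedding(V, D)
nn.init.uniform_(in_emb.weight, -0.5 / D, 0.5 / D)
nn.init.zeros_(out_emb.weight)


def sgns_loss(center, pos, neg):
    v_c = in_emb(center)                                                  # (B, D)
    pos_score = (v_c * out_emb(pos)).sum(-1)                              # (B,)
    neg_score = torch.bmm(out_emb(neg), v_c.unsqueeze(-1)).squeeze(-1)    # (B, K)
    return -(F.logsigmoid(pos_score) + F.logsigmoid(-neg_score).sum(-1)).mean()


# Hypothetical toy data: the positive context of word i is word i + 1.
center = torch.randint(0, V, (B,))
pos = (center + 1) % V
neg = torch.randint(0, V, (B, K))

params = list(in_emb.parameters()) + list(out_emb.parameters())
opt = torch.optim.Adam(params, lr=2.5e-3)
losses = []
for step in range(300):
    opt.zero_grad()
    loss = sgns_loss(center, pos, neg)
    loss.backward()
    opt.step()
    losses.append(loss.item())

print(f"first={losses[0]:.4f}  last={losses[-1]:.4f}")
```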


7. nearest

W = F.normalize(model.in_emb.weight.detach(), dim=1)
q = W[w2i[word]]
sims = (W @ q).cpu().numpy()

F.normalize(..., dim=1) makes each row unit-norm. Then W @ q is cosine similarity (since cos(a,b) = a_unit · b_unit).

We use the input embeddings (in_emb) for both the query word and the candidate neighbors. This is a convention; out_emb can be used the same way.
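The same lookup in plain NumPy, as a sketch with a hypothetical four-word matrix: normalize rows, take dot products with the query row, sort, and drop the query itself (its cosine with itself is always 1).

```python
import numpy as np


def nearest(W, vocab, w2i, word, k=3):
    """Top-k cosine neighbours of `word`, excluding the word itself.

    W: (V, D) embedding matrix, vocab: list of words, w2i: word -> row index.
    """
    Wn = W / np.linalg.norm(W, axis=1, keepdims=True)   # unit-norm rows
    q = w2i[word]
    sims = Wn @ Wn[q]                                     # cosine similarity
    order = np.argsort(-sims)
    out = [(vocab[i], round(float(sims[i]), 2)) for i in order if i != q]
    return out[:k]


# Hypothetical toy embeddings.
vocab = ["king", "queen", "car", "truck"]
w2i = {w: i for i, w in enumerate(vocab)}
W = np.array([[1.0, 0.9, 0.0],
              [0.9, 1.0, 0.1],
              [0.0, 0.1, 1.0],
              [0.1, 0.0, 0.9]])
print(nearest(W, vocab, w2i, "king"))
```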


8. Expected output

After 3 epochs (~10 min on a 4090, ~30 min on CPU):

chars=70123  tokens=17,005,207
  ep 0 step    1000  loss=4.2143
  ep 2 step  100000  loss=1.6234

Nearest neighbors:
  king        [('prince', 0.71), ('queen', 0.69), ('throne', 0.62), ...]
  paris       [('france', 0.73), ('london', 0.66), ('berlin', 0.62), ...]
  computer    [('computers', 0.78), ('software', 0.71), ('hardware', 0.66), ...]

Sanity bar: if king's top-5 doesn't include queen, something is wrong — most likely (a) min_count too high, (b) too few epochs, (c) you accidentally averaged input+output before training.


9. The famous analogy test

v = W[w2i["king"]] - W[w2i["man"]] + W[w2i["woman"]]
# nearest to v, excluding king/man/woman → should produce "queen"

This is the demo that made Word2Vec famous. Works because the embedding space encodes gender as a roughly linear direction.

It also fails in revealing ways: try nurse - woman + man and you may get doctor. Stereotyped associations like this motivated research on debiasing embeddings and on counterfactual data augmentation.
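Here is the analogy search as a small NumPy sketch on hand-built, hypothetical 3-d vectors (axis 0 roughly "royalty", axis 1 roughly "gender"). It normalizes rows, forms b - a + c, and excludes the three query words before taking the argmax. That exclusion is part of the standard evaluation, and without it the top hit is often one of the inputs.

```python
import numpy as np


def analogy(W, vocab, w2i, a, b, c, k=1):
    """Answer "a is to b as c is to ?" via argmax_x cos(x, b - a + c),
    excluding a, b and c themselves."""
    Wn = W / np.linalg.norm(W, axis=1, keepdims=True)   # unit-norm rows
    target = Wn[w2i[b]] - Wn[w2i[a]] + Wn[w2i[c]]
    target /= np.linalg.norm(target)
    sims = Wn @ target                                    # cosine similarity
    for w in (a, b, c):
        sims[w2i[w]] = -np.inf                            # exclude query words
    return [vocab[i] for i in np.argsort(-sims)[:k]]


# Hypothetical toy vectors: [royalty, gender, unrelated].
vocab = ["man", "woman", "king", "queen", "apple"]
w2i = {w: i for i, w in enumerate(vocab)}
W = np.array([[0.0,  1.0, 0.0],
              [0.0, -1.0, 0.0],
              [1.0,  1.0, 0.0],
              [1.0, -1.0, 0.0],
              [0.0,  0.0, 1.0]])
print(analogy(W, vocab, w2i, "man", "king", "woman"))   # ['queen']
```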


10. Common pitfalls

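  1. Forgetting subsampling → much slower training (Mikolov et al. report a 2–10× speedup from subsampling) and worse vectors for rare words.
  2. Multiple DataLoader workers with an IterableDataset → each worker iterates its own copy of the dataset; if the RNG is created in __init__, all copies share the same state and yield the same pair sequence. Use num_workers=0 here (or shard the corpus per worker via torch.utils.data.get_worker_info() and seed each worker differently).
  3. log(sigmoid(x)) instead of F.logsigmoid(x) → -inf/NaN losses once scores grow large in magnitude.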
  4. Sampling without replacement for negatives → biases toward rare words.
  5. Only positive pairs (no negatives) → embeddings collapse to one vector.
  6. Computing similarity without normalizing → returns dot products, correlated with vector norms.

11. Stretch exercises

  • Add CBOW (Continuous Bag-of-Words): predict center from average of context. Compare quality.
  • Implement GloVe (Pennington 2014): factorize the global co-occurrence matrix's log-counts.
  • Visualize with t-SNE/UMAP. Plot 5000 most-frequent words. Observe clusters: countries, days, professions.
  • Replicate Levy & Goldberg: SGNS implicitly factorizes a word-context PMI matrix shifted by log K. Compute an SVD of the shifted positive PMI (SPPMI) matrix and compare cosine sims.
  • Plug into a downstream task (e.g., SST-2 sentiment). Compare to randomly-initialized embeddings.
  • FastText extension: hash character n-grams; sum subword vectors. Handles OOV.

12. What this lab proves about you

You can implement the foundational embedding model without scaffolding, derive the SGNS loss from cross-entropy, explain every hyperparameter, and link it forward to attention (which generalizes "context = nearby tokens" to "context = all tokens with learned weights"). Phase-2 milestone.