Lab 01 — Word2Vec Skip-Gram with Negative Sampling (Solution Walkthrough)
Phase: 2 — Classical NLP & Embeddings | Difficulty: ⭐⭐⭐☆☆ | Time: 3–5 hours
Concept primer: ../HITCHHIKERS-GUIDE.md §Word2Vec. This document walks through the code in solution.py and explains every non-obvious choice.
Run
pip install -r requirements.txt
wget http://mattmahoney.net/dc/text8.zip && unzip text8.zip -d data/
python solution.py --data ./data/text8 --epochs 3
0. The mission
Train a 100-dim word embedding on text8 (a 100 MB cleaned slice of English Wikipedia) using Skip-Gram with Negative Sampling (SGNS). At the end:
nearest("king") → ["queen", "prince", "throne", "kings", "monarch", ...]
nearest("paris") → ["france", "london", "berlin", "vienna", "rome", ...]
…with no labels, just from co-occurrence. The experiment that launched modern NLP (Mikolov et al. 2013).
1. The math
For each (center $c$, context $o$) pair:
$$ \mathcal{L} = -\log \sigma(v_c \cdot v_o) - \sum_{k=1}^{K} \log \sigma(-v_c \cdot v_{n_k}), \quad n_k \sim P_n $$
where $P_n(w) \propto \text{freq}(w)^{0.75}$ is the negative-sampling distribution, and $K$ (5–20) is the number of negatives per positive.
Two embedding tables: an input matrix $V$ for centers, an output matrix $U$ for context/negatives. By convention we keep $V$ as the final embeddings.
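As a sanity check on the formula, here is a minimal sketch in plain Python that evaluates the loss for a single (center, context) pair with K = 2 negatives. The vectors and the helper names (log_sigmoid, dot, sgns_pair_loss) are made up for illustration and are not taken from solution.py.

```python
import math

def log_sigmoid(x):
    """Numerically stable log(sigmoid(x))."""
    return min(x, 0.0) - math.log1p(math.exp(-abs(x)))

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def sgns_pair_loss(v_c, v_o, negatives):
    """Loss for one pair: one positive term plus one term per negative."""
    loss = -log_sigmoid(dot(v_c, v_o))
    for v_n in negatives:
        loss -= log_sigmoid(-dot(v_c, v_n))
    return loss

# Toy 2-d vectors (illustrative values only).
v_c = [1.0, 0.0]
aligned = [2.0, 0.0]    # points the same way as the center
opposed = [-2.0, 0.0]   # points the opposite way

good = sgns_pair_loss(v_c, aligned, [opposed, opposed])
bad = sgns_pair_loss(v_c, opposed, [aligned, aligned])
zero = sgns_pair_loss([0.0, 0.0], aligned, [opposed, opposed])

print(f"context aligned, negatives opposed: {good:.4f}")
print(f"context opposed, negatives aligned: {bad:.4f}")
print(f"all scores zero: {zero:.4f}  (= 3 * log 2)")
```

The loss is small when the true context scores high and the negatives score low, large in the reverse case, and exactly (1 + K) log 2 when every dot product is zero.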
2. build_vocab — three things in one
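from collections import Counter

import numpy as np

def build_vocab(words, min_count=5):
    counts = Counter(words)
    # Keep only words seen at least min_count times.
    vocab = [w for w, c in counts.items() if c >= min_count]
    w2i = {w: i for i, w in enumerate(vocab)}
    freqs = np.array([counts[w] for w in vocab], dtype=np.float64)
    # Negative-sampling distribution: counts raised to 0.75, then normalized.
    neg_dist = freqs ** 0.75
    neg_dist /= neg_dist.sum()
    return w2i, vocab, neg_dist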
The exponent 0.75 is Mikolov's empirical choice. An exponent below 1 flattens the unigram distribution, so very common words are drawn less often than their raw frequency would dictate (you don't want every negative to be "the"). Keeping it above 0 stops short of a uniform distribution, which would over-sample rare words that make uninformative negatives.
min_count=5: drop any word seen fewer than 5 times. For text8 (~17M tokens) this prunes ~250k unique words to ~70k and removes most typos and proper-noun fluff.
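To make the effect of the 0.75 exponent concrete, the following sketch compares the raw unigram distribution with the smoothed one on a three-word toy corpus. The words and counts are invented for illustration; only the exponent comes from the text above.

```python
from collections import Counter

# Toy corpus: one very frequent, one mid-frequency, one rare word (made-up counts).
words = ["the"] * 900 + ["king"] * 90 + ["zygote"] * 10
counts = Counter(words)
total = sum(counts.values())

# Raw unigram probabilities.
unigram = {w: c / total for w, c in counts.items()}

# Counts raised to 0.75, then renormalized.
powered = {w: c ** 0.75 for w, c in counts.items()}
norm = sum(powered.values())
smoothed = {w: p / norm for w, p in powered.items()}

for w in counts:
    print(f"{w:8s} unigram={unigram[w]:.3f}  freq^0.75={smoothed[w]:.3f}")
```

The most frequent word gives up probability mass, the rare word gains some, and the ranking of the three words is unchanged.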
3. SkipGramDataset — subsampling and pair generation
3.1 Frequent-word subsampling
self.keep = np.minimum(1.0, np.sqrt(subsample_t / f) + subsample_t / f)
For each center occurrence, probabilistically drop with probability 1 - keep[w]. Why?
- "the" appears with frequency ~5%. Without subsampling, roughly that share of your training pairs would have "the" as the center, and many more would be centered on other function words. Those pairs are close to useless because "the" co-occurs with everything.
- For very common words f ≫ t (with t=1e-4), so keep ≈ sqrt(t/f) ≪ 1.
- For rare words f ≪ t, so keep saturates at 1 → never dropped.
Mikolov et al. (2013) report that this subsampling speeds up training by roughly 2×–10× and improves the vectors of less frequent words.
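A quick numeric check of the keep formula, as a minimal sketch: the function name keep_prob and the sample frequencies are chosen for illustration, while the formula and t=1e-4 are the ones quoted above.

```python
import math

def keep_prob(f, t=1e-4):
    """Keep probability for a word with relative frequency f."""
    return min(1.0, math.sqrt(t / f) + t / f)

# Frequencies from "the"-like (5e-2) down to rare (1e-5).
for f in (5e-2, 1e-3, 1e-4, 1e-5):
    print(f"f = {f:.0e}: keep = {keep_prob(f):.3f}")
```

A word at 5% frequency is kept a little under 5% of the time, a word at 0.1% is kept roughly 40% of the time, and anything at or below t is always kept.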
3.2 Dynamic window with random shrinking
for i, center in enumerate(self.ids):
    if rng.random() > self.keep[center]:
        continue
    w = rng.randint(1, self.window)  # 👈 random window per sample
    for j in range(max(0, i - w), min(len(self.ids), i + w + 1)):
        if j == i:
            continue
        yield center, self.ids[j]
The window size is resampled per center word. This implicitly weights nearer context words more (they're sampled in every window size; far words only at large window sizes). In expectation this amounts to a linearly decaying (triangular) weighting over distance, at no extra cost.
IterableDataset (vs Dataset) means we stream pairs instead of materializing all ~100M of them.
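The triangular weighting can be checked by enumeration. This sketch assumes the per-center window size is uniform on 1..window inclusive, which is how Python's random.randint behaves (NumPy's randint excludes the upper bound, which would shift the numbers). The function name inclusion_prob and window = 5 are illustrative choices.

```python
from fractions import Fraction

def inclusion_prob(d, window):
    """Probability that a context word at distance d is used when the
    per-center window size w is uniform on {1, ..., window}."""
    hits = sum(1 for w in range(1, window + 1) if d <= w)
    return Fraction(hits, window)

window = 5  # assumed maximum window for this illustration
weights = {d: inclusion_prob(d, window) for d in range(1, window + 1)}
for d, p in weights.items():
    print(f"distance {d}: used with probability {p}")
```

Under that assumption the weight at distance d is (window - d + 1) / window: adjacent words are always used, and the farthest position is used only one time in five.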
4. The model — SkipGramNS
import torch
import torch.nn as nn
import torch.nn.functional as F

class SkipGramNS(nn.Module):
    def __init__(self, vocab_size, dim=100):
        super().__init__()
        self.in_emb = nn.Embedding(vocab_size, dim)   # V: center vectors
        self.out_emb = nn.Embedding(vocab_size, dim)  # U: context / negative vectors
        nn.init.uniform_(self.in_emb.weight, -0.5/dim, 0.5/dim)
        nn.init.zeros_(self.out_emb.weight)
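- Two tables, not one. A word rarely appears in its own context window, yet with a single shared table its score against itself, $v_w \cdot v_w = \|v_w\|^2$, could never be negative. Separate input and output tables remove that constraint.
- Init scale 0.5/dim: keeps dot products $v_c \cdot v_o$ in a sensible range early.
- Output init 0: at step 0, $v_c \cdot v_o = 0$ → $\sigma(0) = 0.5$ → every log-sigmoid term equals $\log 2 \approx 0.69$. The loss as computed in forward adds the positive mean and the negative mean, so it starts at $2 \log 2 \approx 1.39$. Clean baseline.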
def forward(self, center, pos, neg):
    v_c = self.in_emb(center)   # (B, D)
    v_p = self.out_emb(pos)     # (B, D)
    v_n = self.out_emb(neg)     # (B, K, D)
    pos_score = (v_c * v_p).sum(-1)                             # (B,)
    neg_score = torch.bmm(v_n, v_c.unsqueeze(-1)).squeeze(-1)   # (B, K)
    loss = -F.logsigmoid(pos_score).mean() - F.logsigmoid(-neg_score).mean()
    return loss
- (v_c * v_p).sum(-1) is elementwise multiply + sum: the per-row dot product (cheaper than bmm).
- bmm(v_n, v_c.unsqueeze(-1)) is a batched matrix-vector product: for each batch row, K dot products of v_n against v_c.
- F.logsigmoid, not log(sigmoid(x)): numerically stable. In the naive composition, sigmoid(x) underflows to 0 for very negative x, so the log returns -inf and the gradients become nan.
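The same computation can be mirrored in NumPy to check the shapes, the loss at initialization, and the stability claim without a GPU or PyTorch. This is a sketch under stated assumptions: sgns_loss and log_sigmoid are illustrative names, np.einsum stands in for torch.bmm, and the arrays are float64 (float32 underflows at much smaller magnitudes).

```python
import numpy as np

def log_sigmoid(x):
    """Stable log(sigmoid(x)) = min(x, 0) - log1p(exp(-|x|))."""
    return np.minimum(x, 0.0) - np.log1p(np.exp(-np.abs(x)))

def sgns_loss(v_c, v_p, v_n):
    """NumPy mirror of forward: v_c and v_p are (B, D); v_n is (B, K, D)."""
    pos_score = (v_c * v_p).sum(-1)                 # (B,)
    neg_score = np.einsum("bkd,bd->bk", v_n, v_c)   # (B, K), same result as the bmm
    return -log_sigmoid(pos_score).mean() - log_sigmoid(-neg_score).mean()

rng = np.random.default_rng(0)
B, K, D = 4, 5, 100
v_c = rng.uniform(-0.5 / D, 0.5 / D, size=(B, D))   # input rows: small uniform init
v_p = np.zeros((B, D))                               # output rows start at zero
v_n = np.zeros((B, K, D))

init_loss = sgns_loss(v_c, v_p, v_n)
print(f"loss at init: {init_loss:.4f}")              # two mean-reduced terms of log 2 each

# Naive composition: sigmoid(-800) rounds to 0, so the log is -inf.
x = np.array([-800.0])
with np.errstate(over="ignore", divide="ignore"):
    naive = np.log(1.0 / (1.0 + np.exp(-x)))
stable = log_sigmoid(x)
print("naive:", naive[0], "stable:", stable[0])
```

The stable form never exponentiates a positive number, which is why it returns a finite value where the naive one does not.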
5. collate — batching pairs and sampling negatives
negatives = torch.multinomial(neg_dist_t, len(batch) * n_neg, replacement=True).view(-1, n_neg)
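- replacement=True is essential. One call draws len(batch) * n_neg ids; drawn without replacement, no word could appear twice in a batch, which starves frequent words and pushes the sample toward the rare tail.
- We don't filter cases where the negative equals the positive. For a given draw the chance of a collision is $P_n(o)$, which is small for all but the most frequent words, and the other negatives in the row dilute its effect.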
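The sampling step can be sketched without PyTorch. Here random.choices from the standard library stands in for torch.multinomial with replacement=True; the four-word vocabulary, its counts, and the helper sample_negatives are hypothetical.

```python
import random

# Hypothetical 4-word vocabulary with made-up counts.
vocab = ["the", "king", "queen", "zygote"]
counts = [900, 60, 30, 10]

# Same construction as neg_dist: counts ** 0.75, normalized.
weights = [c ** 0.75 for c in counts]
total = sum(weights)
neg_dist = [w / total for w in weights]

def sample_negatives(batch_size, n_neg, rng):
    """Draw a (batch_size, n_neg) grid of word ids, with replacement."""
    flat = rng.choices(range(len(vocab)), weights=neg_dist, k=batch_size * n_neg)
    return [flat[i * n_neg:(i + 1) * n_neg] for i in range(batch_size)]

rng = random.Random(0)
negatives = sample_negatives(batch_size=3, n_neg=5, rng=rng)
for row in negatives:
    print([vocab[i] for i in row])

# Chance that a single negative draw lands on a given word.
for word, p in zip(vocab, neg_dist):
    print(f"P(negative == {word!r}) = {p:.3f}")
```

Fifteen draws from a four-word vocabulary are only possible with replacement, and the per-word probabilities show that a collision with the positive is far from uniform: it is most likely when the positive is itself a frequent word.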
6. The training loop
opt = torch.optim.Adam(model.parameters(), lr=2.5e-3)
LR is high (~10× a typical transformer LR) because (a) embeddings are linear → no exploding-gradient risk, (b) each parameter is touched rarely (sparse access pattern), so per-update steps must be larger.
batch_size=512, n_neg=5 → each step scores 512 positives + 2560 negatives = 3072 dot products.
7. nearest
W = F.normalize(model.in_emb.weight.detach(), dim=1)
q = W[w2i[word]]
sims = (W @ q).cpu().numpy()
F.normalize(..., dim=1) makes each row unit-norm. Then W @ q is cosine similarity (since cos(a,b) = a_unit · b_unit).
We use the input embedding (in_emb) for query and key. Convention; out_emb works similarly.
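A self-contained NumPy version of the lookup, as a sketch: the toy 2-d vectors are invented (not trained), and the signature of nearest here is an assumption rather than the one in solution.py. The third word is given a deliberately large norm to show why normalization matters.

```python
import numpy as np

def nearest(word, W, w2i, vocab, k=3):
    """Top-k cosine neighbours of word, excluding the word itself."""
    W_unit = W / np.linalg.norm(W, axis=1, keepdims=True)
    sims = W_unit @ W_unit[w2i[word]]
    sims[w2i[word]] = -np.inf            # never return the query word
    top = np.argsort(-sims)[:k]
    return [(vocab[i], float(sims[i])) for i in top]

# Toy embeddings (made-up values).
vocab = ["king", "queen", "banana"]
W = np.array([[1.0, 0.2],
              [0.9, 0.3],
              [3.0, 9.0]])               # different direction, large norm
w2i = {w: i for i, w in enumerate(vocab)}

result = nearest("king", W, w2i, vocab, k=2)
print("cosine:", result)

# Same query with raw dot products: the large-norm row wins.
raw = W @ W[w2i["king"]]
raw[w2i["king"]] = -np.inf
raw_best = vocab[int(np.argmax(raw))]
print("raw dot product top-1:", raw_best)
```

With cosine similarity the nearest word is the one pointing in nearly the same direction; with unnormalized dot products the long vector takes first place regardless of direction.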
8. Expected output
After 3 epochs (~10 min on a 4090, ~30 min on CPU):
chars=70123 tokens=17,005,207
ep 0 step 1000 loss=4.2143
ep 2 step 100000 loss=1.6234
Nearest neighbors:
king [('prince', 0.71), ('queen', 0.69), ('throne', 0.62), ...]
paris [('france', 0.73), ('london', 0.66), ('berlin', 0.62), ...]
computer [('computers', 0.78), ('software', 0.71), ('hardware', 0.66), ...]
Sanity bar: if king's top-5 doesn't include queen, something is wrong — most likely (a) min_count too high, (b) too few epochs, (c) you accidentally averaged input+output before training.
9. The famous analogy test
v = W[w2i["king"]] - W[w2i["man"]] + W[w2i["woman"]]
# nearest to v, excluding king/man/woman → should produce "queen"
This is the demo that made Word2Vec famous. Works because the embedding space encodes gender as a roughly linear direction.
It also fails in revealing ways: try nurse - woman + man and you may get doctor. This is the kind of bias that motivated debiasing and counterfactual-augmentation research.
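The analogy lookup can be written as a small function. In this sketch the helper name analogy and the toy 2-d vectors are assumptions for illustration: the first axis plays the role of "royalty" and the second of "gender", which real trained embeddings only approximate.

```python
import numpy as np

def analogy(a, b, c, W, w2i, vocab):
    """Word closest (by cosine) to W[a] - W[b] + W[c], excluding a, b and c."""
    W_unit = W / np.linalg.norm(W, axis=1, keepdims=True)
    v = W_unit[w2i[a]] - W_unit[w2i[b]] + W_unit[w2i[c]]
    sims = W_unit @ (v / np.linalg.norm(v))
    for w in (a, b, c):
        sims[w2i[w]] = -np.inf           # the three query words are not valid answers
    return vocab[int(np.argmax(sims))]

# Toy vectors: axis 0 ~ "royalty", axis 1 ~ "gender" (made-up values).
vocab = ["king", "queen", "man", "woman", "prince", "apple"]
W = np.array([[1.0, 1.0],
              [1.0, -1.0],
              [0.1, 1.0],
              [0.1, -1.0],
              [1.0, 0.8],
              [-1.0, 0.0]])
w2i = {w: i for i, w in enumerate(vocab)}

answer = analogy("king", "man", "woman", W, w2i, vocab)
print(answer)
```

Excluding the three query words matters: without that step the nearest vector to the result is often one of the inputs themselves.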
10. Common pitfalls
- Forgetting subsampling → 2× slower convergence, worse quality.
- Same random seed across DataLoader workers → all workers yield the same pair sequence. Use num_workers=0 here.
- log(sigmoid(x)) instead of F.logsigmoid(x) → inf/NaN losses when scores are large in magnitude.
- Sampling without replacement for negatives → biases toward rare words.
- Only positive pairs (no negatives) → embeddings collapse to one vector.
- Computing similarity without normalizing → returns dot products, correlated with vector norms.
11. Stretch exercises
- Add CBOW (Continuous Bag-of-Words): predict center from average of context. Compare quality.
- Implement GloVe (Pennington 2014): factorize the global co-occurrence matrix's log-counts.
- Visualize with t-SNE/UMAP. Plot 5000 most-frequent words. Observe clusters: countries, days, professions.
- Replicate Levy & Goldberg: SGNS implicitly factorizes a word-context PMI matrix shifted by log K. Compute the SVD of the shifted positive PMI (SPPMI) matrix and compare cosine sims.
- Plug into a downstream task (e.g., SST-2 sentiment). Compare to randomly-initialized embeddings.
- FastText extension: hash character n-grams; sum subword vectors. Handles OOV.
12. What this lab proves about you
You can implement the foundational embedding model without scaffolding, derive the SGNS loss from cross-entropy, explain every hyperparameter, and link it forward to attention (which generalizes "context = nearby tokens" to "context = all tokens with learned weights"). Phase-2 milestone.