Lab 6-01: Vision Transformer (ViT) from Scratch
Learning Objectives
- Understand patch embedding: images are sequences of patches
- Implement positional encodings (learned 1D)
- Build a full Transformer encoder block (MHSA + FFN + LayerNorm)
- Train ViT on synthetic data
- Visualize attention maps (which patches the CLS token attends to)
- Understand ViT vs CNN inductive biases
ViT Architecture
Image (H×W×C)
│
▼ Patch Embedding: split into N patches, linear projection
[P₁, P₂, ..., Pₙ] ← shape: (N, D)
│
▼ Prepend [CLS] token, add positional embeddings
[CLS, P₁+pos₁, P₂+pos₂, ..., Pₙ+posₙ] ← shape: (N+1, D)
│
▼ L × Transformer Encoder Block:
│ ┌─────────────────────────────────────┐
│ │ x = x + MHSA(LayerNorm(x))          │ (pre-norm, residual)
│ │ x = x + FFN(LayerNorm(x))           │
│ └─────────────────────────────────────┘
│
▼ Extract CLS token → MLP Head
Prediction (n_classes)
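The pipeline in the diagram can be sketched end to end in a few dozen lines. This is a minimal illustration, not the lab's reference solution: the names EncoderBlock and TinyViT, the small default sizes, the final LayerNorm before the head, and the use of PyTorch's built-in nn.MultiheadAttention (instead of a hand-written MHSA) are all assumptions made for brevity. The sketch also adds a positional embedding to the CLS slot, as the original ViT does.

```python
import torch
import torch.nn as nn


class EncoderBlock(nn.Module):
    """Pre-norm encoder block: x + MHSA(LN(x)), then x + FFN(LN(x))."""

    def __init__(self, embed_dim=64, n_heads=4, mlp_ratio=4):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        # Assumption: built-in attention stands in for a from-scratch MHSA.
        self.attn = nn.MultiheadAttention(embed_dim, n_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, mlp_ratio * embed_dim),
            nn.GELU(),
            nn.Linear(mlp_ratio * embed_dim, embed_dim),
        )

    def forward(self, x):  # x: (B, N+1, D)
        h = self.norm1(x)
        attn_out, _ = self.attn(h, h, h, need_weights=False)
        x = x + attn_out                 # residual around attention
        x = x + self.ffn(self.norm2(x))  # residual around FFN
        return x


class TinyViT(nn.Module):
    """Hypothetical small ViT that follows the diagram step by step."""

    def __init__(self, img_size=32, patch_size=8, in_channels=3,
                 embed_dim=64, depth=2, n_heads=4, n_classes=10):
        super().__init__()
        n_patches = (img_size // patch_size) ** 2
        self.patch_embed = nn.Conv2d(in_channels, embed_dim,
                                     kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # Learned 1D positional embeddings, one per token including CLS.
        self.pos_embed = nn.Parameter(torch.zeros(1, n_patches + 1, embed_dim))
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        self.blocks = nn.ModuleList(
            [EncoderBlock(embed_dim, n_heads) for _ in range(depth)]
        )
        self.norm = nn.LayerNorm(embed_dim)  # assumption: final norm before head
        self.head = nn.Linear(embed_dim, n_classes)

    def forward(self, x):  # x: (B, C, H, W)
        x = self.patch_embed(x).flatten(2).transpose(1, 2)  # (B, N, D)
        cls = self.cls_token.expand(x.shape[0], -1, -1)     # (B, 1, D)
        x = torch.cat([cls, x], dim=1) + self.pos_embed     # (B, N+1, D)
        for block in self.blocks:
            x = block(x)
        return self.head(self.norm(x[:, 0]))                # CLS token -> logits


model = TinyViT()
logits = model(torch.randn(2, 3, 32, 32))
print(logits.shape)
```

With the default sizes, a 32×32 image yields 16 patches plus the CLS token, so each encoder block processes 17 tokens per image.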
Patch Embedding
import torch.nn as nn

# Split image into patches and project to embedding dimension D
class PatchEmbedding(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        # Number of patches N = (H/P) * (W/P); stored for sizing positional embeddings
        self.n_patches = (img_size // patch_size) ** 2
        # Conv2d with kernel=stride=patch_size = non-overlapping patch extraction
        self.proj = nn.Conv2d(in_channels, embed_dim,
                              kernel_size=patch_size, stride=patch_size)

    def forward(self, x):          # x: (B, C, H, W)
        x = self.proj(x)           # (B, D, H/P, W/P)
        x = x.flatten(2)           # (B, D, N)
        x = x.transpose(1, 2)      # (B, N, D)
        return x
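To see why a strided Conv2d counts as "split into patches, then linear projection", the following sketch compares it with an explicit version: extract the patches with F.unfold, flatten each one, and multiply by the conv weights reshaped into a matrix. The tensor sizes are arbitrary small values chosen for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
B, C, H, W, P, D = 2, 3, 32, 32, 8, 16
x = torch.randn(B, C, H, W)

# Path 1: strided convolution, as in PatchEmbedding.
proj = nn.Conv2d(C, D, kernel_size=P, stride=P)
conv_tokens = proj(x).flatten(2).transpose(1, 2)                # (B, N, D)

# Path 2: explicit patch extraction followed by a linear projection.
patches = F.unfold(x, kernel_size=P, stride=P).transpose(1, 2)  # (B, N, C*P*P)
weight = proj.weight.reshape(D, C * P * P)                      # conv kernel as a matrix
linear_tokens = patches @ weight.T + proj.bias                  # (B, N, D)

print(conv_tokens.shape)
print(torch.allclose(conv_tokens, linear_tokens, atol=1e-4))
```

The two paths should agree up to floating-point tolerance, which is why the convolution is the usual choice: it performs the same computation without materializing the patch tensor.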
Interview Questions
Q: What is the main inductive bias difference between CNNs and ViTs?
A: CNNs have two strong inductive biases baked in: (1) locality: conv filters only look at local neighbourhoods, and (2) translation equivariance: the same filter is applied at every position. ViTs have almost none of this. Apart from the patch split itself, attention is global from the first layer (every patch can attend to every other), and spatial relationships have to be learned through the positional embeddings. As a result, ViTs typically need much more data to learn spatial structure from scratch, but they can model long-range dependencies that CNNs struggle with.
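Both properties can be checked numerically. The sketch below is an illustration under stated assumptions: circular padding is used so that shifting (torch.roll) commutes exactly with the convolution, and PyTorch's nn.MultiheadAttention without positional embeddings stands in for a ViT attention layer. It shows that the convolution is translation equivariant, while bare self-attention is permutation equivariant, meaning it has no built-in notion of which patches are neighbours.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Convolution: shifting the input shifts the output by the same amount.
conv = nn.Conv2d(1, 4, kernel_size=3, padding=1,
                 padding_mode="circular", bias=False)
x = torch.randn(1, 1, 16, 16)
shift = (3, 5)
with torch.no_grad():
    shift_then_conv = conv(torch.roll(x, shifts=shift, dims=(2, 3)))
    conv_then_shift = torch.roll(conv(x), shifts=shift, dims=(2, 3))
conv_equivariant = torch.allclose(shift_then_conv, conv_then_shift, atol=1e-5)

# Self-attention without positional embeddings: any reordering of the tokens
# simply reorders the outputs, so spatial layout carries no information.
attn = nn.MultiheadAttention(embed_dim=8, num_heads=2, batch_first=True)
tokens = torch.randn(1, 10, 8)
perm = torch.randperm(10)
with torch.no_grad():
    out, _ = attn(tokens, tokens, tokens, need_weights=False)
    shuffled = tokens[:, perm]
    out_shuffled, _ = attn(shuffled, shuffled, shuffled, need_weights=False)
attn_permutation_equivariant = torch.allclose(out[:, perm], out_shuffled, atol=1e-5)

print(conv_equivariant, attn_permutation_equivariant)
```

Positional embeddings are what break this permutation symmetry in a real ViT, and they are learned from data rather than fixed by the architecture.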
Q: Why is the CLS token used for classification instead of average pooling?
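A: The CLS token is a learnable token prepended to the sequence. Through self-attention over L layers, it aggregates information from all patch tokens. It is a design choice inherited from BERT rather than a necessity. Global average pooling over the patch tokens also works: Swin Transformer, for example, pools its final feature map, and the original ViT paper reports comparable accuracy with pooling once the learning rate is retuned. DeiT, by contrast, keeps the CLS token and adds a distillation token alongside it.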
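In code, the two readout strategies differ by a single indexing line. The sketch below uses a random tensor as a stand-in for the encoder output (an assumption, since no trained model is involved) and a hypothetical shared linear head.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, N, D, n_classes = 2, 196, 64, 10

# Stand-in for the encoder output, laid out as [CLS, P1, ..., PN].
tokens = torch.randn(B, N + 1, D)
head = nn.Linear(D, n_classes)  # hypothetical classification head

cls_repr = tokens[:, 0]               # CLS readout: (B, D)
gap_repr = tokens[:, 1:].mean(dim=1)  # average over patch tokens only: (B, D)

logits_cls = head(cls_repr)           # (B, n_classes)
logits_gap = head(gap_repr)           # (B, n_classes)
print(logits_cls.shape, logits_gap.shape)
```

Either representation feeds the same kind of head, so switching between them is mostly a training-recipe decision rather than an architectural one.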
Q: What is the computational complexity of self-attention and why does it matter for high-res images?
A: $O(N^2 \cdot D)$, where N is the number of patches. For 224×224 with 16×16 patches, N = 196, which is manageable. For 1024×1024 with 16×16 patches, N = 4096, so the attention matrix is 4096×4096: about 16.8M entries per head, or 64 MiB in fp32, for every layer and every image in the batch. This is why hierarchical approaches (Swin Transformer, window attention) are used for dense prediction tasks on high-resolution images.
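The quadratic growth is easy to verify with a few lines of arithmetic. The helper below is a hypothetical utility written for this illustration; it assumes square images, fp32 storage (4 bytes per entry), a single head and a single image, and it ignores the extra CLS token.

```python
def attention_matrix_stats(img_size, patch_size, bytes_per_elem=4):
    """Return (N, entries, MiB) for one head's N x N attention matrix."""
    n = (img_size // patch_size) ** 2   # number of patches
    entries = n * n                     # one score per (query, key) pair
    mib = entries * bytes_per_elem / 2**20
    return n, entries, mib


for size in (224, 1024):
    n, entries, mib = attention_matrix_stats(size, 16)
    print(f"{size}x{size}: N={n}, entries per head={entries:,}, "
          f"fp32 size={mib:.2f} MiB")
```

Going from 224 to 1024 pixels per side multiplies N by about 21 but the attention matrix by about 437, which is the cost that windowed and hierarchical designs avoid.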