Lab 6-01: Vision Transformer (ViT) from Scratch

Learning Objectives

  • Understand patch embedding: images are sequences of patches
  • Implement positional encodings (learned 1D)
  • Build a full Transformer encoder block (MHSA + FFN + LayerNorm)
  • Train ViT on synthetic data
  • Visualize attention maps (which patches the CLS token attends to)
  • Understand ViT vs CNN inductive biases

ViT Architecture

Image (H×W×C)
    │
    ▼ Patch Embedding: split into N patches, linear projection
[P₁, P₂, ..., Pₙ] ← shape: (N, D)
    │
    ▼ Prepend [CLS] token, add positional embeddings
[CLS, P₁+pos₁, P₂+pos₂, ..., Pₙ+posₙ] ← shape: (N+1, D)
    │
    ▼ L × Transformer Encoder Block:
    │     ┌─────────────────────────────────────┐
    │     │ x = x + MHSA(LayerNorm(x))          │  (pre-norm, residual)
    │     │ x = x + FFN(LayerNorm(x))           │
    │     └─────────────────────────────────────┘
    │
    ▼ Extract CLS token → MLP Head
Prediction (n_classes)
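
The encoder block in the diagram can be sketched in a few lines of PyTorch. This is a minimal illustration, not the lab's reference solution: it swaps in the built-in nn.MultiheadAttention for a hand-written MHSA, and the class name EncoderBlock, the mlp_ratio parameter, and the small default sizes are assumptions chosen for the example.

```python
import torch
import torch.nn as nn

class EncoderBlock(nn.Module):
    """Pre-norm block: x = x + MHSA(LN(x)), then x = x + FFN(LN(x))."""

    def __init__(self, embed_dim=64, n_heads=4, mlp_ratio=4.0, dropout=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        # Built-in attention used here for brevity (assumption, not "from scratch")
        self.attn = nn.MultiheadAttention(embed_dim, n_heads,
                                          dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(embed_dim)
        hidden = int(embed_dim * mlp_ratio)
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, embed_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):  # x: (B, N+1, D)
        h = self.norm1(x)
        # attn_weights: (B, N+1, N+1), averaged over heads by default
        attn_out, attn_weights = self.attn(h, h, h, need_weights=True)
        x = x + attn_out                   # residual around attention
        x = x + self.ffn(self.norm2(x))    # residual around the FFN
        return x, attn_weights
```

Row 0 of attn_weights is the CLS token's attention over all tokens, which is the quantity an attention-map visualization would reshape back onto the patch grid.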

Patch Embedding

# Split image into patches and project to embedding dimension D
import torch.nn as nn

class PatchEmbedding(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        assert img_size % patch_size == 0, "img_size must be divisible by patch_size"
        # Stored on the module so later code can size the positional embeddings
        self.n_patches = (img_size // patch_size) ** 2
        # Conv2d with kernel_size == stride == patch_size extracts non-overlapping
        # patches and applies the same linear projection to each one
        self.proj = nn.Conv2d(in_channels, embed_dim,
                              kernel_size=patch_size, stride=patch_size)

    def forward(self, x):  # x: (B, C, H, W)
        x = self.proj(x)       # (B, D, H/P, W/P)
        x = x.flatten(2)       # (B, D, N), with N = (H/P) * (W/P)
        x = x.transpose(1, 2)  # (B, N, D)
        return x
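
The next step in the diagram, prepending [CLS] and adding learned 1D positional embeddings, might look like the sketch below. The class name TokenPrep and the truncated-normal initialization with std=0.02 are assumptions for illustration. Note one deliberate difference from the diagram: this version also gives the CLS slot its own positional embedding, which is a common implementation choice.

```python
import torch
import torch.nn as nn

class TokenPrep(nn.Module):  # hypothetical helper name
    def __init__(self, n_patches=196, embed_dim=768):
        super().__init__()
        # One learnable CLS vector, shared across the batch
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # Learned 1D positional embeddings: one row per token (CLS + N patches)
        self.pos_embed = nn.Parameter(torch.zeros(1, n_patches + 1, embed_dim))
        nn.init.trunc_normal_(self.cls_token, std=0.02)  # assumed init scale
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

    def forward(self, patches):  # patches: (B, N, D)
        batch_size = patches.shape[0]
        cls = self.cls_token.expand(batch_size, -1, -1)  # (B, 1, D)
        x = torch.cat([cls, patches], dim=1)             # (B, N+1, D)
        return x + self.pos_embed                        # broadcast over the batch
```

With the defaults above, a (B, 196, 768) tensor of patch embeddings becomes a (B, 197, 768) token sequence ready for the encoder.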

Interview Questions

Q: What is the main inductive bias difference between CNNs and ViTs?
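A: CNNs have two strong inductive biases baked in: (1) locality: conv filters only look at local neighbourhoods, and (2) translation equivariance: the same filter is applied at every position. ViTs keep these only in the patch embedding, where one shared projection is applied to each local patch. After that, attention is global from the first layer (every patch can attend to every other) and the 2D layout has to be learned through positional embeddings. As a result, ViTs typically need much more data, or strong augmentation and regularization, to learn spatial structure from scratch, but they can model long-range dependencies that CNNs only capture after stacking many layers.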
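
A small experiment can make the two biases concrete. The sketch below is an illustration under stated assumptions (circular padding so that shifts wrap around cleanly, tiny random tensors, no positional embeddings on the attention side): it checks that a convolution commutes with a spatial shift, and that plain self-attention commutes with any shuffling of its tokens, meaning it has no built-in notion of which patches are neighbours.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# CNN side: shifting the input shifts the output (translation equivariance).
# Circular padding is an assumption that makes the check exact at the borders.
conv = nn.Conv2d(1, 4, kernel_size=3, padding=1,
                 padding_mode="circular", bias=False)
image = torch.randn(1, 1, 8, 8)
shifted = torch.roll(image, shifts=2, dims=-1)
with torch.no_grad():
    conv_equivariant = torch.allclose(
        conv(shifted), torch.roll(conv(image), shifts=2, dims=-1), atol=1e-5)

# ViT side: self-attention without positional embeddings treats tokens as a set.
# Permuting the input tokens just permutes the outputs in the same way.
attn = nn.MultiheadAttention(16, 4, batch_first=True)
tokens = torch.randn(1, 9, 16)
perm = torch.randperm(9)
with torch.no_grad():
    out, _ = attn(tokens, tokens, tokens)
    shuffled = tokens[:, perm]
    out_shuffled, _ = attn(shuffled, shuffled, shuffled)
attn_permutation_equivariant = torch.allclose(out[:, perm], out_shuffled, atol=1e-5)

print(conv_equivariant, attn_permutation_equivariant)
```

Both checks should pass. The second one is why positional embeddings are needed: without them the model cannot distinguish an image from a scrambled version of its patches.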

Q: Why is the CLS token used for classification instead of average pooling?
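A: The CLS token is a learnable token prepended to the sequence. Through self-attention over L layers, it aggregates information from all patch tokens, so its final state can serve as a summary of the image. It is a design choice inherited from BERT rather than a requirement. Global average pooling over the patch tokens also works: the original ViT paper reports comparable accuracy once the learning rate is retuned, and Swin Transformer uses average pooling instead of a CLS token.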
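
The two pooling options differ by a single line. The helper below is a hypothetical sketch (the name pool_tokens and the mode strings are made up for this example) that assumes the encoder output has the CLS token at index 0.

```python
import torch

def pool_tokens(x, mode="cls"):
    """Reduce encoder output x of shape (B, N+1, D) to one vector per image."""
    if mode == "cls":
        return x[:, 0]               # final state of the CLS token: (B, D)
    if mode == "mean":
        return x[:, 1:].mean(dim=1)  # average over the N patch tokens: (B, D)
    raise ValueError(f"unknown pooling mode: {mode!r}")

encoded = torch.randn(2, 197, 768)   # stand-in for real encoder output
cls_features = pool_tokens(encoded, "cls")
mean_features = pool_tokens(encoded, "mean")
```

Either result has shape (B, D) and can be fed to the same MLP head, so switching between them is a cheap ablation to run.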

Q: What is the computational complexity of self-attention and why does it matter for high-res images?
A: $O(N^2 \cdot D)$ in time, where N = number of patches and D = embedding dimension, plus an $N \times N$ attention matrix per head in memory. For 224×224 with 16×16 patches: N=196, manageable. For 1024×1024 with 16×16 patches: N=4096, so the attention matrix is 4096×4096, about 16.8M entries per head (roughly 64 MiB in fp32), per layer and per image. This is why hierarchical approaches (Swin Transformer, window attention) are used for dense prediction tasks on high-resolution images.
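
The arithmetic is easy to check directly. The function below is a hypothetical helper (name and the 4-bytes-per-entry fp32 default are assumptions) that counts the entries in one head's attention matrix for a square image.

```python
def attention_matrix_stats(img_size, patch_size, bytes_per_entry=4):
    """Return (N, N*N, MiB) for one attention head on one square image."""
    n_patches = (img_size // patch_size) ** 2
    entries = n_patches * n_patches
    mib = entries * bytes_per_entry / 2**20
    return n_patches, entries, mib

for size in (224, 1024):
    n_patches, entries, mib = attention_matrix_stats(size, 16)
    print(f"{size}x{size}: N={n_patches}, entries={entries:,}, fp32 size={mib:.2f} MiB")
```

For 224×224 this gives N=196 and 38,416 entries (about 0.15 MiB). For 1024×1024 it gives N=4096 and 16,777,216 entries, which is 64 MiB in fp32 for a single head, before multiplying by heads, layers, and batch size.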