Lab 02 — Training Loop Best Practices
Phase 3: PyTorch | Week 7-8
A good training loop is often the difference between a model that diverges and one that trains reliably. These patterns appear in most production codebases.
Learning Objectives
- Build a production-grade training loop with validation
- Implement early stopping, gradient clipping, and checkpointing
- Compare LR schedulers: StepLR, CosineAnnealingLR, OneCycleLR
- Use Automatic Mixed Precision (AMP) correctly on GPU
- Debug training instability with gradient norm monitoring
Theory
The Complete Training Loop
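import torch
from torch.amp import GradScaler, autocast  # older PyTorch: torch.cuda.amp.GradScaler
from torch.nn.utils import clip_grad_norm_

scaler = GradScaler()  # dynamic loss scaling for FP16
for epoch in range(n_epochs):
    model.train()
    for x, y in train_loader:
        optimizer.zero_grad(set_to_none=True)  # drop grads instead of zero-filling them
        with autocast(device_type='cuda'):  # AMP: forward pass and loss in reduced precision
            loss = criterion(model(x), y)
        scaler.scale(loss).backward()  # backward pass on the scaled loss
        scaler.unscale_(optimizer)  # restore true gradient magnitudes
        clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)  # skipped if gradients contain inf/NaN
        scaler.update()
    scheduler.step()  # per epoch; per-batch schedulers such as OneCycleLR step inside the batch loop
    model.eval()
    with torch.no_grad():
        val_loss = evaluate(model, val_loader)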
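The loop above calls `evaluate(model, val_loader)` without defining it. A minimal sketch of such a helper follows; its signature, the default `CrossEntropyLoss` criterion, and the `device` argument are assumptions for illustration, not the lab's actual implementation.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def evaluate(model, loader, criterion=None, device="cpu"):
    """Return the mean per-sample loss over `loader` (hypothetical helper)."""
    if criterion is None:
        criterion = nn.CrossEntropyLoss()  # assumed default: classification loss
    model.eval()  # disable dropout, use BatchNorm running statistics
    total_loss, n_samples = 0.0, 0
    with torch.no_grad():  # no autograd graph is needed for validation
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            loss = criterion(model(x), y)
            # Weight by batch size so a short final batch is not over-counted.
            total_loss += loss.item() * x.size(0)
            n_samples += x.size(0)
    return total_loss / max(n_samples, 1)


# Usage on made-up data: 10 samples, 4 features, 3 classes, batches of 4, 4, 2.
torch.manual_seed(0)
data = TensorDataset(torch.randn(10, 4), torch.randint(0, 3, (10,)))
print(evaluate(nn.Linear(4, 3), DataLoader(data, batch_size=4)))
```

Weighting each batch loss by its size makes the result equal to the loss over the whole validation set, even when the last batch is smaller.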
Automatic Mixed Precision (AMP)
FP32: 32-bit floats, full precision, more memory.
FP16: 16-bit floats, half the memory of FP32, with Tensor Core acceleration (up to roughly 16× the peak non-Tensor-Core FP32 throughput on A100).
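Loss scaling: FP16 has a narrow dynamic range (normal values span roughly $6 \times 10^{-5}$ to $6.5 \times 10^{4}$), so small gradients can underflow to 0. Multiply the loss by a large factor $S$ before the backward pass, then divide the gradients by $S$ before the optimizer update.
PyTorch's GradScaler handles this automatically with dynamic scaling: by default it halves $S$ and skips the optimizer step when gradients overflow (inf/NaN), and doubles $S$ after 2000 consecutive steps without overflow.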
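To make the underflow and the scaling policy concrete, here is a small standard-library sketch. `to_fp16` and `ToyScaler` are hypothetical names; `ToyScaler` only imitates the halve/double policy described above and is not PyTorch's implementation.

```python
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to IEEE half precision and back (struct format 'e')."""
    return struct.unpack("e", struct.pack("e", x))[0]


tiny_grad = 1e-8                   # made-up gradient value
print(to_fp16(tiny_grad))          # 0.0: the value underflows in FP16
S = 2.0 ** 16                      # loss scale
scaled = to_fp16(tiny_grad * S)    # about 6.55e-4, representable in FP16
print(scaled / S)                  # unscaled in higher precision: close to 1e-8


class ToyScaler:
    """Toy model of the dynamic-scale policy (assumption: mirrors the defaults above)."""

    def __init__(self, init_scale=2.0 ** 16, growth_interval=2000):
        self.scale = init_scale
        self.growth_interval = growth_interval
        self.good_steps = 0  # consecutive steps without overflow

    def update(self, found_inf: bool) -> None:
        if found_inf:
            self.scale *= 0.5      # back off; the optimizer step would be skipped
            self.good_steps = 0
        else:
            self.good_steps += 1
            if self.good_steps == self.growth_interval:
                self.scale *= 2.0  # grow again after a clean streak
                self.good_steps = 0


toy = ToyScaler(growth_interval=3)  # short interval so the effect is visible
toy.update(found_inf=True)          # overflow: scale halves to 32768
for _ in range(3):
    toy.update(found_inf=False)     # three clean steps: scale doubles back
print(toy.scale)
```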
BF16: Brain Float 16, with the same 8-bit exponent range as FP32 but only 7 mantissa bits. Loss scaling is usually unnecessary. Preferred on A100/H100.
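The range-versus-precision trade-off between FP16 and BF16 can be sketched with the standard library alone. The helper names are hypothetical, and `to_bf16` truncates rather than rounding to nearest, which is a simplification of what hardware typically does.

```python
import struct


def to_fp16(x: float) -> float:
    """Round to IEEE half precision (5 exponent bits, 10 mantissa bits)."""
    return struct.unpack("e", struct.pack("e", x))[0]


def to_bf16(x: float) -> float:
    """Keep only the top 16 bits of the FP32 encoding (sign, 8 exponent, 7 mantissa)."""
    (bits,) = struct.unpack(">I", struct.pack(">f", x))
    return struct.unpack(">f", struct.pack(">I", bits & 0xFFFF0000))[0]


# Range: a tiny value survives in BF16 but underflows to zero in FP16.
print(to_bf16(1e-8), to_fp16(1e-8))

# Precision: FP16 resolves 1.001, while BF16 collapses it to 1.0.
print(to_bf16(1.001), to_fp16(1.001))

# Large values: BF16 can hold 1e10, FP16 cannot (max is 65504).
try:
    to_fp16(1e10)
except OverflowError as err:
    print("FP16 overflow:", err)
print(to_bf16(1e10))
```

Because BF16 keeps the FP32 exponent, gradients rarely underflow, which is why the loss-scaling machinery is normally not needed with it.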
Gradient Clipping
Prevents exploding gradients (common in RNNs, deep networks):
$$\text{if } \|\nabla\|_2 > \text{max\_norm}: \quad \nabla \leftarrow \nabla \cdot \frac{\text{max\_norm}}{\|\nabla\|_2}$$
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
When training with GradScaler, always clip after scaler.unscale_() and before scaler.step().
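`clip_grad_norm_` also returns the total gradient norm measured before clipping, which makes it a convenient hook for gradient norm monitoring. The sketch below uses a made-up, fully deterministic setup (zeroed weights, constant inputs and targets) so the numbers can be checked by hand; AMP is left out to keep it runnable on CPU.

```python
import torch
from torch import nn
from torch.nn.utils import clip_grad_norm_

model = nn.Linear(8, 1)
with torch.no_grad():
    for p in model.parameters():
        p.zero_()  # zero weights and bias, so every prediction is 0

x = torch.ones(16, 8)
y = torch.full((16, 1), 100.0)  # large targets produce large gradients

loss = nn.functional.mse_loss(model(x), y)
loss.backward()  # each of the 9 gradient entries is -200, global norm 600

# Rescales gradients in place; the return value is the norm BEFORE clipping.
norm_before = clip_grad_norm_(model.parameters(), max_norm=1.0)

# Recompute the global L2 norm across all parameters after clipping.
norm_after = torch.sqrt(sum(p.grad.pow(2).sum() for p in model.parameters()))

print(f"before: {norm_before.item():.1f}, after: {norm_after.item():.4f}")
```

Logging `norm_before` every step gives an early warning: a norm that climbs steadily or spikes shortly before the loss turns to NaN usually points at the learning rate or the data.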
Learning Rate Schedulers
| Scheduler | Behavior | Best For |
|---|---|---|
| StepLR | Decay by $\gamma$ every $k$ epochs | Simple baselines |
| CosineAnnealingLR | Cosine decay to $\eta_{min}$ | ResNets, most CNNs |
| OneCycleLR | Warmup → peak → cosine decay (1 cycle) | Fast training (fewer epochs) |
| ReduceLROnPlateau | Reduce LR when metric plateaus | When you don't know n_epochs |
| WarmupCosine (custom, not a built-in PyTorch scheduler) | Linear warmup + cosine decay | Transformers |
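OneCycleLR schedule: with PyTorch's defaults (pct_start=0.3, anneal_strategy='cos'), the LR rises from max_lr / div_factor to max_lr over the first 30% of steps, then anneals down to a much smaller final value. Both phases follow a cosine curve by default; pass anneal_strategy='linear' for linear ramps. OneCycleLR is stepped once per batch, not once per epoch.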
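The table rows can be turned into closed-form expressions to see how the curves differ. This is a plain-Python sketch, not the PyTorch scheduler classes; the function names and the exact warmup convention in `warmup_cosine_lr` are assumptions.

```python
import math


def step_lr(base_lr, epoch, step_size, gamma):
    """StepLR-style decay: multiply the LR by gamma every step_size epochs."""
    return base_lr * gamma ** (epoch // step_size)


def cosine_lr(base_lr, t, t_max, eta_min=0.0):
    """Cosine annealing from base_lr at t=0 down to eta_min at t=t_max."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * t / t_max))


def warmup_cosine_lr(base_lr, step, warmup_steps, total_steps, eta_min=0.0):
    """Linear warmup to base_lr, then cosine decay to eta_min (hypothetical helper)."""
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps
    return cosine_lr(base_lr, step - warmup_steps, total_steps - warmup_steps, eta_min)


# Compare the three curves at a few points of a made-up 100-epoch run.
for epoch in (0, 5, 10, 25, 50, 100):
    print(
        epoch,
        round(step_lr(0.1, epoch, step_size=10, gamma=0.5), 5),
        round(cosine_lr(0.1, epoch, t_max=100), 5),
        round(warmup_cosine_lr(0.1, epoch, warmup_steps=10, total_steps=100), 5),
    )
```

In PyTorch, a comparable warmup-plus-cosine schedule is commonly built by wrapping a function like `warmup_cosine_lr` in `LambdaLR`, or by chaining `LinearLR` and `CosineAnnealingLR` with `SequentialLR`.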
Early Stopping
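class EarlyStopping:
    """Stop when validation loss has not improved by min_delta for `patience` checks."""

    def __init__(self, patience=10, min_delta=1e-4):
        self.patience = patience
        self.min_delta = min_delta  # smallest decrease that counts as an improvement
        self.counter = 0  # consecutive checks without improvement
        self.best_loss = float('inf')

    def __call__(self, val_loss) -> bool:
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
            return False  # continue
        self.counter += 1
        return self.counter >= self.patience  # stop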
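A short trace shows how the counter behaves. The class is restated here (with `min_delta` stored on the instance) so the snippet runs on its own; the validation losses are made up for illustration.

```python
class EarlyStopping:
    def __init__(self, patience=10, min_delta=1e-4):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = float("inf")

    def __call__(self, val_loss) -> bool:
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
            return False
        self.counter += 1
        return self.counter >= self.patience


# Made-up validation losses: improvement stalls after the third epoch.
val_losses = [0.90, 0.70, 0.60, 0.61, 0.60, 0.62, 0.59]
stopper = EarlyStopping(patience=3, min_delta=1e-3)

stopped_at = None
for epoch, val_loss in enumerate(val_losses):
    if stopper(val_loss):
        stopped_at = epoch  # third consecutive epoch without improvement
        break

print(stopped_at, stopper.best_loss)
```

Training stops at epoch index 5, so the final 0.59 is never seen. That is the trade-off `patience` controls: a small value saves compute but can stop just before a late improvement. In practice the best checkpoint (here, the one from epoch 2) is what gets restored.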
What the Lab Covers
| Section | Content |
|---|---|
| SyntheticImageDataset | Custom Dataset + DataLoader + pin_memory |
| SimpleCNN | 3-block CNN with BatchNorm |
| EarlyStopping | Patience-based stopping |
| train_one_epoch() | AMP + GradScaler + gradient clipping |
| lr_scheduler_comparison() | Plot 4 schedulers side-by-side |
| checkpoint_demo() | Save/load model + optimizer state |
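As an illustration of the checkpointing row, the sketch below saves and restores model and optimizer state. The helper names, the dictionary keys, and the file name are assumptions; the lab's `checkpoint_demo()` may be organized differently.

```python
import os
import tempfile

import torch
from torch import nn


def save_checkpoint(path, model, optimizer, epoch):
    """Save what is needed to resume training (hypothetical helper)."""
    torch.save(
        {"epoch": epoch, "model": model.state_dict(), "optimizer": optimizer.state_dict()},
        path,
    )


def load_checkpoint(path, model, optimizer):
    """Restore model and optimizer state in place; return the saved epoch."""
    ckpt = torch.load(path, map_location="cpu")
    model.load_state_dict(ckpt["model"])
    optimizer.load_state_dict(ckpt["optimizer"])
    return ckpt["epoch"]


model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
model(torch.randn(8, 4)).sum().backward()
optimizer.step()  # creates momentum buffers, so the optimizer has real state to save

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "ckpt.pt")
    save_checkpoint(path, model, optimizer, epoch=3)

    fresh_model = nn.Linear(4, 2)
    fresh_opt = torch.optim.SGD(fresh_model.parameters(), lr=0.1, momentum=0.9)
    resumed_epoch = load_checkpoint(path, fresh_model, fresh_opt)

print(resumed_epoch, torch.equal(fresh_model.weight, model.weight))
```

Saving the optimizer state matters for optimizers with momentum or adaptive statistics: restoring only the weights resets those buffers and can cause a visible jump in the loss on resume. When training with AMP or an LR scheduler, their `state_dict()` values would typically be stored alongside.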
Interview Questions
Q: Why zero_grad(set_to_none=True) instead of zero_grad()?
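A: Setting gradients to None skips the write of zeros and releases the gradient tensors until the next backward pass, which is slightly faster and lowers peak memory. For standard training the results are the same, but None is not a zero tensor: code that reads .grad must handle None, and optimizers skip parameters whose gradient is None. Since PyTorch 2.0, set_to_none=True is the default.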
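The difference is easy to observe directly. A minimal sketch on a toy model:

```python
import torch
from torch import nn

model = nn.Linear(3, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

model(torch.ones(2, 3)).sum().backward()
optimizer.zero_grad(set_to_none=False)
# The gradient tensor is kept and filled with zeros.
grad_is_zero_tensor = bool((model.weight.grad == 0).all())

model(torch.ones(2, 3)).sum().backward()
optimizer.zero_grad(set_to_none=True)
# The gradient tensor is released; .grad is None until the next backward().
grad_is_none = model.weight.grad is None

print(grad_is_zero_tensor, grad_is_none)
```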
Q: Why does gradient clipping go between scaler.unscale_() and scaler.step()?
A: GradScaler.unscale_() divides gradients by the scale factor, restoring their true magnitudes. You must clip the true gradients, not the scaled ones. Otherwise, your clip threshold is meaningless.
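A plain-Python sketch with made-up numbers shows what goes wrong when the order is reversed. `clip_by_global_norm` is a hypothetical stand-in for norm-based clipping, not the PyTorch function.

```python
import math


def l2(grads):
    """Global L2 norm of a flat list of gradient values."""
    return math.sqrt(sum(g * g for g in grads))


def clip_by_global_norm(grads, max_norm):
    """Rescale grads so their L2 norm is at most max_norm."""
    coef = min(1.0, max_norm / (l2(grads) + 1e-6))
    return [g * coef for g in grads]


S = 2.0 ** 16                 # loss scale
true_grads = [0.3, 0.4]       # true norm 0.5, below the clip threshold of 1.0
scaled = [g * S for g in true_grads]

# Wrong order: clip the scaled gradients, then unscale.
wrong = [g / S for g in clip_by_global_norm(scaled, max_norm=1.0)]

# Right order: unscale first, then clip.
right = clip_by_global_norm([g / S for g in scaled], max_norm=1.0)

print(l2(wrong), l2(right))
```

In the wrong order every gradient looks enormous, so it is always clipped, and after unscaling the update is shrunk by roughly the scale factor. In the right order these gradients are left untouched because their true norm is already below the threshold.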
Q: When should you use OneCycleLR vs CosineAnnealingLR?
A: OneCycleLR is best when you know the total number of steps in advance and want the fastest convergence (fewer epochs). CosineAnnealingLR is better when training is more exploratory, or when you may extend or restart training later.
Q: What causes NaN loss during training?
A: (1) Learning rate too high. (2) Log of 0 or division by 0 in the loss. (3) FP16 overflow without loss scaling. (4) Bad data (inf/NaN in the input). Add assert torch.isfinite(loss) early when debugging; it catches inf as well as NaN.
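Cause (2) and the finite-loss check can be sketched as follows. The probabilities are made up, and `check_finite` is a hypothetical helper, not part of PyTorch.

```python
import torch

probs = torch.tensor([0.0, 0.5, 1.0])  # made-up predicted probabilities

naive_loss = -torch.log(probs).mean()                 # log(0) = -inf, so the loss is inf
safe_loss = -torch.log(probs.clamp(min=1e-8)).mean()  # clamping keeps the log finite


def check_finite(loss, step):
    """Fail fast with the step number instead of training on NaN/inf."""
    if not torch.isfinite(loss).all():
        raise FloatingPointError(f"non-finite loss {loss.item()} at step {step}")


check_finite(safe_loss, step=0)  # passes silently
try:
    check_finite(naive_loss, step=1)
except FloatingPointError as err:
    print(err)
```

Failing at the first non-finite loss keeps the step number and the offending batch at hand, which is much easier to debug than noticing NaN weights many steps later.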
Run
pip install -r requirements.txt
python solution.py
# Outputs saved to outputs/