GPU, TPU & AI Accelerator Architecture

"Pick the right hardware for the job" is a system design competency that separates senior from mid-level CV engineers.


The Memory Hierarchy Problem

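Many CV workloads are memory-bandwidth-bound rather than compute-bound on GPUs: elementwise ops, normalization, and small-batch inference spend more time moving data than doing arithmetic. The bottleneck is moving data between the levels of the hierarchy (approximate peak bandwidths, A100-class GPU):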

┌────────────────────────────────────────────────────────────────┐
│  Host (CPU) DRAM                                               │
│      ↕ PCIe 4.0 ×16: ~32 GB/s per direction                    │
│  GPU HBM (VRAM)     ~2 TB/s (A100 80GB), ~3.35 TB/s (H100 SXM) │
│      ↕                                                         │
│  L2 cache           ~5 TB/s                                    │
│      ↕                                                         │
│  L1 / shared memory ~20 TB/s (aggregate across SMs)            │
│      ↕                                                         │
│  Registers          ~80 TB/s                                   │
└────────────────────────────────────────────────────────────────┘

Key insight: Minimize CPU↔GPU data transfers. Keep data resident on GPU across operations.
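A rough back-of-envelope sketch of why both points hold. It plugs in approximate A100-class figures (PCIe 4.0 ×16 at ~32 GB/s, HBM at ~2 TB/s, and the A100's ~312 TFLOPS dense FP16 Tensor Core peak); all constants are assumptions, not measurements, and the roofline "ridge point" is the FLOPs-per-byte ratio below which a kernel cannot saturate compute.

```python
# Approximate peak figures for an A100-class system (assumptions, not measurements)
PCIE4_X16_BYTES_PER_S = 32e9   # PCIe 4.0 x16, one direction
HBM_BYTES_PER_S = 2.0e12       # A100 80GB HBM2e
FP16_TENSOR_FLOPS = 312e12     # A100 dense FP16 Tensor Core peak


def transfer_ms(num_bytes: float, bytes_per_s: float) -> float:
    """Milliseconds to move num_bytes at a sustained bandwidth."""
    return num_bytes / bytes_per_s * 1e3


def ridge_point(flops_per_s: float, bytes_per_s: float) -> float:
    """Roofline ridge point in FLOP/byte: kernels below it are bandwidth-bound."""
    return flops_per_s / bytes_per_s


batch_bytes = 64 * 3 * 640 * 640 * 4      # 64 RGB 640x640 images in FP32
pcie_ms = transfer_ms(batch_bytes, PCIE4_X16_BYTES_PER_S)
hbm_ms = transfer_ms(batch_bytes, HBM_BYTES_PER_S)

relu_intensity = 1 / 4                    # FP16 ReLU: ~1 FLOP per 2 bytes read + 2 bytes written
ridge = ridge_point(FP16_TENSOR_FLOPS, HBM_BYTES_PER_S)

print(f"batch: {batch_bytes / 1e6:.0f} MB")
print(f"PCIe: {pcie_ms:.1f} ms vs HBM: {hbm_ms:.2f} ms")
print(f"ridge point: {ridge:.0f} FLOP/byte vs ReLU: {relu_intensity} FLOP/byte")
```

Under these assumptions the PCIe copy takes roughly 60× longer than an HBM read of the same batch, and ReLU sits hundreds of times below the ridge point, so more Tensor Core FLOPs would not speed it up; only moving fewer bytes (for example, fusing it into the preceding conv) would.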


CUDA Programming Model

Thread Hierarchy

Grid
└── Block (max 1024 threads)
    └── Thread
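To make the hierarchy concrete, here is a serial Python emulation (not a real kernel launch) of the common 1-D CUDA indexing idiom `blockIdx.x * blockDim.x + threadIdx.x`. The helper names and the block size of 256 are illustrative choices, not part of any CUDA API.

```python
import math

MAX_THREADS_PER_BLOCK = 1024  # CUDA's per-block thread limit
WARP_SIZE = 32


def launch_config(n: int, block_dim: int = 256) -> tuple:
    """Return (grid_dim, block_dim) so that every element gets one thread."""
    if not 0 < block_dim <= MAX_THREADS_PER_BLOCK:
        raise ValueError("block_dim must be in 1..1024")
    return math.ceil(n / block_dim), block_dim


def warps_per_block(block_dim: int) -> int:
    """Blocks are scheduled as whole warps, so a partial warp still occupies a full one."""
    return math.ceil(block_dim / WARP_SIZE)


def emulate_vector_add(a, b, block_dim: int = 256):
    """Serially run what a 1-D CUDA kernel computing out[i] = a[i] + b[i] does in parallel."""
    n = len(a)
    out = [0] * n
    grid_dim, block_dim = launch_config(n, block_dim)
    for block_idx in range(grid_dim):               # blockIdx.x
        for thread_idx in range(block_dim):         # threadIdx.x
            i = block_idx * block_dim + thread_idx  # global thread index
            if i < n:  # tail guard: the last block may be only partly used
                out[i] = a[i] + b[i]
    return out


print(launch_config(1000))    # (4, 256)
print(warps_per_block(100))   # 4
print(emulate_vector_add([1, 2, 3], [10, 20, 30], block_dim=2))  # [11, 22, 33]
```

The `if i < n` guard matters because the grid is rounded up to whole blocks: with 1,000 elements and 256 threads per block, 4 blocks launch 1,024 threads and the last 24 do no work.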
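  • Warp: 32 threads that execute in lockstep (SIMT). When threads of one warp take different sides of a branch (if/else), the two paths run one after the other and the lanes not on the active path sit idle (warp divergence).
  • Occupancy: ratio of active warps on an SM to the maximum the SM supports. Higher occupancy helps hide memory latency, because the scheduler can switch to another warp while one waits on memory.
  • Shared memory: programmer-managed on-chip memory carved out of the same storage as L1 (up to 164 KB per SM on A100, 228 KB on H100; kernels that need more than 48 KB per block must opt in). Critical for tiled matrix multiplication.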
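A simplified theoretical-occupancy calculator, assuming A100 (compute capability 8.0) per-SM limits: 64 warps, 32 resident blocks, 65,536 registers, and up to 164 KB of shared memory. It ignores register and shared-memory allocation granularity, so treat the result as an estimate; profilers such as Nsight Compute report the exact figure. `SMLimits` and `theoretical_occupancy` are illustrative names.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SMLimits:
    """Per-SM limits assumed for an A100 (compute capability 8.0)."""
    max_warps: int = 64
    max_blocks: int = 32
    registers: int = 65536
    shared_mem_bytes: int = 164 * 1024


A100 = SMLimits()


def theoretical_occupancy(threads_per_block: int, regs_per_thread: int,
                          smem_per_block: int = 0, sm: SMLimits = A100) -> float:
    """Active warps / max warps, limited by whichever resource runs out first.
    Simplified: ignores register and shared-memory allocation granularity."""
    warps_per_block = -(-threads_per_block // 32)  # ceil division
    limits = [
        sm.max_blocks,                                          # resident-block cap
        sm.max_warps // warps_per_block,                        # warp slots
        sm.registers // (regs_per_thread * threads_per_block),  # register file
    ]
    if smem_per_block:
        limits.append(sm.shared_mem_bytes // smem_per_block)    # shared memory
    blocks_per_sm = min(limits)
    return blocks_per_sm * warps_per_block / sm.max_warps


print(theoretical_occupancy(256, 32))             # 1.0
print(theoretical_occupancy(256, 64))             # 0.5 (register-limited)
print(theoretical_occupancy(256, 32, 48 * 1024))  # 0.375 (shared-memory-limited)
```

Doubling registers per thread from 32 to 64 halves occupancy here. Higher occupancy usually helps hide latency, but it does not guarantee a faster kernel; a register-heavy kernel at 50% occupancy can still win.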

Memory Types

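| Memory | Scope | Lifetime | Bandwidth (approx., A100) |
|---|---|---|---|
| Register | Thread | Kernel | Fastest |
| Shared | Block | Kernel | ~20 TB/s (aggregate) |
| L1 / L2 cache | SM / whole GPU | Hardware-managed | ~20 TB/s (L1), ~5 TB/s (L2) |
| Global (HBM) | All threads | Application | ~2 TB/s |
| Pinned (host) | CPU (GPU via DMA) | Application | ~32 GB/s over PCIe 4.0 ×16 (async copies; zero-copy capable) |
| Unified | CPU + GPU | Application | Slower when pages fault and migrate |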

PyTorch CUDA Best Practices

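# Assumes `model`, `dataset`, and an input batch `x` already exist
import torch
from torch.utils.data import DataLoader

device = torch.device("cuda")
model = model.to(device)  # move the model once, before any loop

# ✅ Pin host memory so CPU→GPU copies can use asynchronous DMA
loader = DataLoader(dataset, batch_size=64, pin_memory=True, num_workers=4)

# ✅ Non-blocking transfer: overlaps the copy with GPU compute (only if the source is pinned)
x = x.to(device, non_blocking=True)

# ✅ Mixed precision: eligible ops run in FP16 on Tensor Cores (often 2-4× throughput vs FP32)
with torch.autocast(device_type="cuda", dtype=torch.float16):
    output = model(x)

# ✅ torch.compile (PyTorch 2.0+): fuses ops, reduces kernel launches
model = torch.compile(model)  # speedup is model-dependent; benchmark before and after

# ✅ Profile to find the actual bottleneck
with torch.profiler.profile(
    activities=[torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA],
    with_stack=True
) as prof:
    model(x)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

# ❌ Device mismatch: a model left on the CPU that receives CUDA tensors raises a
#    RuntimeError on the first batch (a loud failure, not a silent bug).
#    Fix: call model.to(device) once before the loop, as above.
cpu_model = MyModel()  # hypothetical model class, never moved to the GPU
for batch in loader:
    cpu_model(batch.cuda())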

TensorRT: Production GPU Inference

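TensorRT is NVIDIA's inference optimizer and runtime. It converts a trained model (commonly exported via ONNX) into an engine tuned for a specific GPU:

Optimization Steps

  1. Layer and tensor fusion: fuse Conv+BN+ReLU into a single kernel (fewer kernel launches and memory round-trips)
  2. Precision calibration: FP32 → FP16 or INT8 with minimal accuracy loss (INT8 needs calibration data or a quantization-aware-trained model)
  3. Kernel auto-tuning: benchmarks multiple CUDA kernel implementations and picks the fastest for your GPU
  4. Dynamic tensor memory: reuses memory between tensors whose lifetimes do not overlap, cutting footprint and allocation overhead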
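TensorRT performs Conv+BN fusion internally; the NumPy sketch below only shows the arithmetic behind it and is not TensorRT code. At inference time BatchNorm is a fixed per-channel affine map, so it can be folded into the preceding convolution's weights and bias. A 1×1 convolution keeps the example short; the same per-output-channel formula applies to any kernel size. The function names are illustrative.

```python
import numpy as np


def fold_bn_into_conv(W, b, gamma, beta, mean, var, eps=1e-5):
    """Fold inference-mode BatchNorm into the preceding convolution.

    W: conv weights (C_out, C_in, kH, kW); b: conv bias (C_out,).
    BN(y) = gamma * (y - mean) / sqrt(var + eps) + beta, per output channel.
    """
    scale = gamma / np.sqrt(var + eps)          # (C_out,)
    W_folded = W * scale[:, None, None, None]   # rescale each output filter
    b_folded = (b - mean) * scale + beta
    return W_folded, b_folded


def conv1x1(x, W, b):
    """1x1 convolution of x (C_in, H, W): a per-pixel matrix multiply over channels."""
    return np.einsum("oi,ihw->ohw", W[:, :, 0, 0], x) + b[:, None, None]


def batchnorm(y, gamma, beta, mean, var, eps=1e-5):
    c = (slice(None), None, None)  # broadcast per-channel params over H and W
    return gamma[c] * (y - mean[c]) / np.sqrt(var[c] + eps) + beta[c]


rng = np.random.default_rng(0)
C_in, C_out = 3, 4
x = rng.standard_normal((C_in, 5, 5))
W = rng.standard_normal((C_out, C_in, 1, 1))
b = rng.standard_normal(C_out)
gamma, beta = rng.standard_normal(C_out), rng.standard_normal(C_out)
mean, var = rng.standard_normal(C_out), rng.random(C_out) + 0.5  # variance must be positive

unfused = np.maximum(batchnorm(conv1x1(x, W, b), gamma, beta, mean, var), 0)  # Conv -> BN -> ReLU
W_f, b_f = fold_bn_into_conv(W, b, gamma, beta, mean, var)
fused = np.maximum(conv1x1(x, W_f, b_f), 0)                                    # one conv + ReLU
print(np.allclose(unfused, fused))  # True
```

After folding, the BN layer disappears entirely, and the elementwise ReLU can be applied in the same kernel as the convolution's output write.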

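Precision vs speed (A100 SXM, peak dense throughput; 2:4 structured sparsity doubles the Tensor Core figures):

| Precision | Peak TFLOPS (TOPS for INT) | Memory per value (vs FP32) | Use case |
|---|---|---|---|
| FP32 | 19.5 (CUDA cores, no Tensor Cores) | 100% | Training, debugging |
| TF32 | 156 | 100% (stored as FP32) | FP32 training on A100 (on by default for cuDNN convolutions; matmuls need opt-in since PyTorch 1.12) |
| FP16 | 312 | 50% | Training (AMP), inference |
| BF16 | 312 | 50% | Training (FP32's exponent range, so more stable than FP16 and usually no loss scaling) |
| INT8 | 624 | 25% | Deployment inference |
| INT4 | 1248 | 12.5% | LLM serving (emerging) |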
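The INT8 row's 25% memory figure comes from storing one byte per value plus a scale factor. A minimal NumPy sketch of per-tensor symmetric quantization, assuming simple max-based calibration; TensorRT's INT8 calibrators (for example, its entropy calibrator) choose the clipping range more carefully. The helper names are illustrative.

```python
import numpy as np


def quantize_int8_symmetric(x: np.ndarray):
    """Per-tensor symmetric quantization: map [-max|x|, max|x|] onto [-127, 127]."""
    amax = float(np.max(np.abs(x)))
    scale = amax / 127.0 if amax > 0 else 1.0
    q = np.clip(np.round(x / scale), -127, 127).astype(np.int8)
    return q, scale


def dequantize(q: np.ndarray, scale: float) -> np.ndarray:
    """Map INT8 codes back to approximate FP32 values."""
    return q.astype(np.float32) * np.float32(scale)


rng = np.random.default_rng(0)
activations = rng.standard_normal(10_000).astype(np.float32)
q, scale = quantize_int8_symmetric(activations)
recon = dequantize(q, scale)
max_err = float(np.max(np.abs(recon - activations)))

print(activations.nbytes, "->", q.nbytes, "bytes")  # 40000 -> 10000 bytes
print(f"max abs error {max_err:.4f}, bound scale/2 = {scale / 2:.4f}")
```

Rounding error is bounded by half a quantization step (scale/2). A single outlier inflates `scale` and coarsens every other value, which is why the choice of calibration range matters.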

TensorRT Python (ONNX export path)

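import torch
import onnx
import tensorrt as trt

# Step 1: Export to ONNX (assumes `model` is a trained torch.nn.Module)
model = model.cuda().eval()
dummy_input = torch.randn(1, 3, 640, 640, device='cuda')
torch.onnx.export(
    model, dummy_input, "model.onnx",
    input_names=['images'], output_names=['output'],
    dynamic_axes={'images': {0: 'batch'}, 'output': {0: 'batch'}},
    opset_version=17
)
onnx.checker.check_model(onnx.load("model.onnx"))  # catch malformed graphs early

# Step 2: Build a TensorRT engine (TensorRT 8.x API)
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(TRT_LOGGER)
config = builder.create_builder_config()
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 2 << 30)  # 2 GiB workspace
config.set_flag(trt.BuilderFlag.FP16)  # allow FP16 kernels

# create_network takes a bitmask, so shift the enum value
# (in TensorRT 10 explicit batch is always on and this flag is deprecated)
flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
network = builder.create_network(flags)
parser = trt.OnnxParser(network, TRT_LOGGER)
with open("model.onnx", 'rb') as f:
    if not parser.parse(f.read()):
        for i in range(parser.num_errors):
            print(parser.get_error(i))
        raise RuntimeError("ONNX parsing failed")

# A dynamic batch axis requires an optimization profile: (min, opt, max) input shapes
profile = builder.create_optimization_profile()
profile.set_shape('images', (1, 3, 640, 640), (8, 3, 640, 640), (32, 3, 640, 640))
config.add_optimization_profile(profile)

engine = builder.build_serialized_network(network, config)
if engine is None:
    raise RuntimeError("TensorRT engine build failed")
with open("model.trt", 'wb') as f:
    f.write(engine)

# Step 3: Benchmark, e.g. `trtexec --loadEngine=model.trt --shapes=images:8x3x640x640`.
# FP16 engines are often ~2-4× faster than PyTorch eager mode; measure on your model.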

TPU Architecture (Google Cloud)

TPUs are designed specifically for matrix multiply (the dominant operation in deep learning).

MXU (Matrix Multiply Unit)

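Each TPU v4 chip contains two TensorCores (a v4 host board carries four chips). Per chip:

  • 8 MXUs (4 per TensorCore): 128×128 systolic arrays, each performing 16,384 multiply-accumulates (32,768 FLOPs) per cycle
  • HBM: 32 GB per chip
  • Interconnect: high-bandwidth ICI links chips into multi-chip (pod) setups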

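Why a systolic array? Data flows through a grid of processing elements; each element performs one multiply-accumulate per cycle and passes its inputs on to its neighbors. Every operand fetched from memory is reused across a whole row or column of the array, which removes most of the memory-bandwidth pressure for large matrix multiplies.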
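A cycle-by-cycle Python simulation of a small systolic array computing C = A·B. This is an output-stationary variant chosen for readability: each processing element (PE) keeps one output accumulator while A values move right and B values move down, skewed by one cycle per row or column. Google's TPU papers describe the MXU as weight-stationary, so treat this as an illustration of the data-reuse idea, not the MXU's exact dataflow; `systolic_matmul` is an illustrative name.

```python
def systolic_matmul(A, B):
    """Simulate an n x m output-stationary systolic array computing A (n x k) @ B (k x m).

    PE (i, j) accumulates C[i][j]. Row i of A enters the left edge delayed by i cycles and
    column j of B enters the top edge delayed by j cycles, so A[i][s] and B[s][j] meet at
    PE (i, j) on cycle s + i + j.
    """
    n, k, m = len(A), len(A[0]), len(B[0])
    acc = [[0] * m for _ in range(n)]
    a_out = [[0] * m for _ in range(n)]  # value each PE forwards to its right neighbour
    b_out = [[0] * m for _ in range(n)]  # value each PE forwards to the PE below
    for t in range(n + m + k - 2):       # cycles until the last operand pair meets
        new_a = [[0] * m for _ in range(n)]
        new_b = [[0] * m for _ in range(n)]
        for i in range(n):
            for j in range(m):
                if j == 0:               # left edge: feed skewed row i of A
                    s = t - i
                    a_in = A[i][s] if 0 <= s < k else 0
                else:                    # otherwise take the left neighbour's output
                    a_in = a_out[i][j - 1]
                if i == 0:               # top edge: feed skewed column j of B
                    s = t - j
                    b_in = B[s][j] if 0 <= s < k else 0
                else:                    # otherwise take the upper neighbour's output
                    b_in = b_out[i - 1][j]
                acc[i][j] += a_in * b_in  # one multiply-accumulate per PE per cycle
                new_a[i][j], new_b[i][j] = a_in, b_in
        a_out, b_out = new_a, new_b      # registers latch at the end of the cycle
    return acc


print(systolic_matmul([[1, 2], [3, 4]], [[5, 6], [7, 8]]))  # [[19, 22], [43, 50]]
```

Each element of A and B enters the array once, at its edge, and is then reused by every PE along its row or column without another memory read.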

TPU vs GPU for CV

| Aspect | GPU (A100) | TPU v4 |
|---|---|---|
| Flexibility | High (arbitrary CUDA ops) | Lower (everything must compile through XLA) |
| Custom ops | Easy | Hard (must be XLA-compatible) |
| Memory | 40 or 80 GB | 32 GB per chip |
| Multi-device | NVLink / PCIe | ICI fabric (seamless) |
| Best for | Research, inference | Large-scale training (LLMs, ViT) |
| Cost (cloud, approx.; varies by provider and region) | $3-8/hr | $2-6/chip/hr |
| Framework | PyTorch / TF / JAX | JAX (best), TF, PyTorch/XLA |

JAX on TPU

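import jax
import jax.numpy as jnp
from jax import jit, vmap, grad

# Pure functions on immutable arrays: XLA can trace and compile the whole computation
@jit  # compile with XLA → fast on TPU (also runs on GPU/CPU backends)
def forward(params, x):
    return jnp.dot(x, params['W']) + params['b']

# vmap: vectorize over the batch dimension without explicit loops
# (in_axes=(None, 0): params are shared, x is mapped over axis 0)
batched_forward = vmap(forward, in_axes=(None, 0))

# grad: automatic differentiation w.r.t. the first argument (params); no .backward()
def loss_fn(params, x, y):
    return jnp.mean((batched_forward(params, x) - y) ** 2)

grad_fn = jit(grad(loss_fn))

# pmap: data-parallel over local TPU cores; params are broadcast (in_axes=None) and
# x needs a leading axis of size jax.local_device_count()
parallel_forward = jax.pmap(forward, in_axes=(None, 0))

# Example call (shapes are illustrative)
params = {'W': jax.random.normal(jax.random.PRNGKey(0), (4, 2)), 'b': jnp.zeros(2)}
x, y = jnp.ones((8, 4)), jnp.zeros((8, 2))
grads = grad_fn(params, x, y)  # pytree with the same structure as params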

When to Choose TPU

  • Training Vision Transformers (ViT), BERT-scale models
  • Large-batch training where GPU memory limits batch size
  • When using JAX/Flax (native TPU framework)
  • NOT recommended: models with dynamic shapes, complex custom CUDA ops

Hardware Selection Guide

Inference: Latency vs Throughput

Latency requirements:
< 10ms   → GPU (A10G, T4) with TensorRT + FP16
10-100ms → GPU or CPU (depends on model size)
> 100ms  → CPU may be sufficient (saves cost)

Throughput requirements:
> 1000 req/s → GPU cluster with batching (Triton Inference Server)
              → Consider NVIDIA A100 with batch_size=64+
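Batching trades a little latency for throughput. A toy simulation of a dynamic batcher with hypothetical parameters `max_batch` and `max_delay_ms`: a batch is dispatched when it is full or when its oldest request has waited too long. Triton's dynamic batcher exposes similar knobs (preferred batch sizes and a maximum queue delay), but this sketch is not its implementation.

```python
from typing import List


def form_batches(arrivals_ms: List[float], max_batch: int, max_delay_ms: float) -> List[List[int]]:
    """Group request indices (arrival order) into batches.

    A batch is dispatched when it reaches max_batch requests, or once its oldest request
    has waited longer than max_delay_ms. Assumes arrivals_ms is sorted ascending.
    """
    batches: List[List[int]] = []
    current: List[int] = []
    opened_at = 0.0
    for idx, t in enumerate(arrivals_ms):
        if current and t - opened_at > max_delay_ms:
            batches.append(current)       # timer expired before this request arrived
            current = []
        if not current:
            opened_at = t                 # this request starts a new batch
        current.append(idx)
        if len(current) == max_batch:
            batches.append(current)       # full: dispatch immediately
            current = []
    if current:
        batches.append(current)           # flush the final partial batch
    return batches


print(form_batches([0, 1, 2, 10, 11, 30], max_batch=3, max_delay_ms=5))
# [[0, 1, 2], [3, 4], [5]]
```

Raising `max_delay_ms` yields fuller batches and better GPU utilization, at the cost of up to `max_delay_ms` of extra queueing latency per request.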

Training Hardware

Dataset size:
Small (< 100k images):   Single RTX 4090 (24GB, consumer GPU)
Medium (< 1M images):    Single A100 (80GB) or 4× A6000
Large (> 1M images):     Multi-GPU DDP (8× A100) or TPU pod

Real-world CV System Hardware Stack

Edge (camera):      NVIDIA Jetson AGX Orin 64GB (275 INT8 TOPS with sparsity, 64GB unified memory)
On-premise:         4× A100 80GB SXM + NVLink (for training)
Cloud inference:    AWS g4dn.xlarge (T4 GPU, ~$0.53/hr on-demand) with auto-scaling
Cloud training:     AWS p4d.24xlarge (8× A100 40GB, ~$32/hr on-demand)

Interview Questions

Q: A team wants to deploy a YOLOv8 model that runs at 30ms on A100 in PyTorch. The customer needs 10ms. What would you do?

A: I'd attack this in order of impact:

  1. TensorRT conversion with FP16: typically 2-4× speedup → might reach 8-15ms
  2. INT8 quantization if accuracy permits: another 1.5-2× on top of FP16
  3. Input resolution reduction: YOLOv8 at 416px vs 640px is ~2× faster
  4. torch.compile if staying in PyTorch: ~20-40% speedup with minimal effort
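  5. Model distillation: train a smaller student model (e.g. YOLOv8n) with the larger model as teacher
  6. Hardware upgrade: A100 → H100. Not a code change, but immediate (also confirm the production GPU matches the A100 used for the 30ms benchmark)
  7. Batching: raises throughput, not single-frame latency (it usually adds latency); it only helps if the 10ms target is really a per-frame throughput budget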
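A quick sanity check on the plan is to multiply the expected gains. This sketch assumes the speedups compound independently and that latency scales with pixel count; both assumptions are optimistic (fixed costs such as NMS and pre/post-processing do not shrink), so profile after each step. `estimate_latency_ms` is an illustrative helper.

```python
from typing import Iterable


def estimate_latency_ms(baseline_ms: float, speedups: Iterable[float],
                        baseline_px: int = 640, target_px: int = 640) -> float:
    """Optimistic latency estimate: divide by each speedup, then scale by the pixel-count ratio."""
    latency = baseline_ms
    for factor in speedups:
        latency /= factor
    return latency * (target_px / baseline_px) ** 2


print(estimate_latency_ms(30, [2.0]))                      # TensorRT FP16 (low end of 2-4×)
print(estimate_latency_ms(30, [2.0, 1.5]))                 # + INT8 (low end of 1.5-2×)
print(estimate_latency_ms(30, [2.0, 1.5], target_px=416))  # + 416px input
```

With the conservative ends of the quoted ranges, the estimate reaches the 10ms target before resolution is touched; dropping to 416px would leave headroom for the pessimism in the assumptions.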

Q: Explain the difference between DataParallel and DistributedDataParallel.

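A: DataParallel (DP) uses one process with one thread per GPU. On every forward pass it re-replicates the model from GPU 0 to all N GPUs, splits the batch, runs forward in parallel, and gathers the outputs on GPU 0 for the loss; during backward, the gradients are summed back onto GPU 0. Problems: (1) GIL contention: the per-GPU Python threads cannot execute Python code in parallel, (2) GPU 0 is the gathering bottleneck: it holds the gathered outputs and reduced gradients, so it carries more load and memory than the others (load imbalance), (3) the model is re-broadcast every iteration, and (4) it is limited to a single machine.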

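DistributedDataParallel (DDP) spawns one process per GPU. Each process has its own model replica, optimizer, and data loader (typically with a DistributedSampler so each rank sees a different shard). During the backward pass, gradients are averaged across processes with AllReduce (NCCL uses ring or tree algorithms), and DDP groups gradients into buckets so communication overlaps with the rest of the backward computation. There is no single-GPU bottleneck, and scaling is close to linear as long as the interconnect keeps up. PyTorch's documentation recommends DDP over DP even on a single machine; DP is effectively legacy.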
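DDP delegates AllReduce to NCCL. To show what a ring AllReduce does, here is a pure-Python simulation in which lists stand in for per-GPU gradient buffers and no real communication happens; `ring_allreduce` is an illustrative name. It runs the two classic phases: reduce-scatter, after which each worker owns the complete sum of one chunk, and all-gather, which circulates those finished chunks around the ring.

```python
from typing import List


def ring_allreduce(buffers: List[List[float]]) -> List[List[float]]:
    """Simulate ring AllReduce (sum) across len(buffers) workers; returns each worker's result."""
    n = len(buffers)
    data = [list(buf) for buf in buffers]             # each worker's local gradient buffer
    length = len(data[0])
    bounds = [c * length // n for c in range(n + 1)]  # chunk c = [bounds[c], bounds[c + 1])

    def chunk(c):
        return range(bounds[c], bounds[c + 1])

    def ring_step(chunk_to_send, accumulate):
        # Every worker i sends one chunk to worker (i + 1) % n at the same time,
        # so snapshot all outgoing values before anyone applies what it received.
        outgoing = [(c, [data[i][k] for k in chunk(c)])
                    for i, c in enumerate(chunk_to_send)]
        for i, (c, values) in enumerate(outgoing):
            dst = (i + 1) % n
            for k, v in zip(chunk(c), values):
                data[dst][k] = data[dst][k] + v if accumulate else v

    # Phase 1, reduce-scatter: after n - 1 steps worker i holds the full sum of chunk (i + 1) % n
    for step in range(n - 1):
        ring_step([(i - step) % n for i in range(n)], accumulate=True)
    # Phase 2, all-gather: pass the finished chunks around the ring, overwriting stale copies
    for step in range(n - 1):
        ring_step([(i + 1 - step) % n for i in range(n)], accumulate=False)
    return data


grads = [[1, 2, 3, 4, 5], [10, 20, 30, 40, 50], [100, 200, 300, 400, 500]]
print(ring_allreduce(grads))  # every worker ends with [111, 222, 333, 444, 555]
```

Each worker sends and receives 2(N−1)/N of its buffer in total, so the per-GPU bandwidth cost stays roughly constant as GPUs are added. DDP then divides the sum by the world size to get the average.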

Q: Why is mixed precision training numerically unstable, and how does GradScaler fix it?

A: FP16 has a limited dynamic range (from ~6×10⁻⁸, the smallest subnormal, up to 65,504). Many gradient values are very small and underflow to zero in FP16; this is called gradient underflow. GradScaler multiplies the loss by a large scale factor before backward (PyTorch starts at 2¹⁶ by default), and by the chain rule every gradient is scaled by the same factor, lifting it into FP16's representable range. Before the optimizer step it unscales the gradients and checks for inf/NaN. If any are found, it skips the optimizer step and reduces the scale factor (halving it by default). After a run of steps without overflow (2,000 by default), it doubles the scale factor. Under autocast, matmuls and convolutions run in FP16 for speed while precision-sensitive ops (e.g. reductions, softmax) stay in FP32, and the master weights remain in FP32 so small updates are not rounded away.
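A pure-Python sketch of the failure mode and the fix. `to_fp16` round-trips a float through IEEE half precision using the standard library's `struct` 'e' format; `ToyLossScaler` is a hypothetical, simplified scaler that follows the rules described above, with defaults matching PyTorch's GradScaler (initial scale 2¹⁶, ×2 growth after 2,000 clean steps, ×0.5 backoff). It is not PyTorch's implementation.

```python
import math
import struct


def to_fp16(x: float) -> float:
    """Round x to the nearest IEEE 754 half-precision value using struct's 'e' format."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except (OverflowError, struct.error):  # too large for FP16: hardware would produce inf
        return math.copysign(math.inf, x)


class ToyLossScaler:
    """Hypothetical, simplified dynamic loss scaler (not PyTorch's GradScaler implementation)."""

    def __init__(self, init_scale=2.0 ** 16, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._clean_steps = 0

    def unscale(self, fp16_grads):
        """Divide scaled FP16 gradients by the scale (in full precision) and check for inf/NaN."""
        found_inf = any(math.isinf(g) or math.isnan(g) for g in fp16_grads)
        return [g / self.scale for g in fp16_grads], found_inf

    def update(self, found_inf: bool) -> bool:
        """Adjust the scale; return True if the optimizer step should be applied."""
        if found_inf:
            self.scale *= self.backoff_factor  # overflow: shrink the scale and skip this step
            self._clean_steps = 0
            return False
        self._clean_steps += 1
        if self._clean_steps == self.growth_interval:
            self.scale *= self.growth_factor   # long run without overflow: try a larger scale
            self._clean_steps = 0
        return True


tiny_grad = 1e-8
print(to_fp16(tiny_grad))                    # 0.0 (underflows in FP16)
scaler = ToyLossScaler()
scaled = to_fp16(tiny_grad * scaler.scale)   # ~6.55e-4, comfortably representable in FP16
grads, found_inf = scaler.unscale([scaled])
print(grads[0], found_inf, scaler.update(found_inf))  # ~1e-8 recovered, no overflow, step applied
```

Scaling recovers the tiny gradient to within FP16 rounding error, while an overflow (`inf`) in any scaled gradient causes the step to be skipped and the scale to shrink.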