Lab 7-02: FastAPI Inference Server

Learning Goals

  • Build a production-ready REST API for ML model inference
  • Implement async request handling and dynamic batching
  • Add health checks, Prometheus metrics, and request logging
  • Handle concurrent requests without blocking the event loop or overloading the GPU

Core Concepts

Why FastAPI?

FastAPI is a popular choice for serving Python ML models because it offers:

  • Native async support via asyncio
  • Automatic OpenAPI/Swagger docs
  • Pydantic validation for request/response schemas
  • Higher throughput than Flask under concurrent load (2-3× is a commonly quoted figure; benchmark your own workload)

Dynamic Batching

Batching requests together raises GPU utilization, since one forward pass over a batch is usually far cheaper than the same number of single-item passes. The key tradeoff:

Latency ↑ (wait for batch to fill)
    vs
Throughput ↑ (process more requests per second)

Strategy: Collect requests for a configurable window (e.g., 10 ms) or until the batch is full, then process them as one batch.

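import asyncio

import torch

MAX_BATCH = 32           # assumed cap; tune to your model and GPU memory
BATCH_WINDOW_S = 0.010   # wait at most 10 ms for a batch to fill

# Items are (tensor, future) pairs; each future receives one request's result
batch_queue: asyncio.Queue = asyncio.Queue()

async def batch_processor():
    loop = asyncio.get_running_loop()
    while True:
        # Block until the first request arrives so an idle server does not spin
        item, future = await batch_queue.get()
        batch, futures = [item], [future]
        # Collect more requests until the window closes or the batch is full
        deadline = loop.time() + BATCH_WINDOW_S
        while len(batch) < MAX_BATCH:
            remaining = deadline - loop.time()
            if remaining <= 0:
                break
            try:
                item, future = await asyncio.wait_for(batch_queue.get(), timeout=remaining)
            except asyncio.TimeoutError:
                break
            batch.append(item)
            futures.append(future)

        try:
            # model_inference is the lab's model call, defined elsewhere. Running it
            # in a worker thread keeps the forward pass from blocking the event loop.
            results = await loop.run_in_executor(None, model_inference, torch.stack(batch))
        except Exception as exc:
            # Fail every waiting request rather than leaving its future pending
            for future in futures:
                if not future.done():
                    future.set_exception(exc)
            continue
        for result, future in zip(results, futures):
            if not future.done():  # the caller may have been cancelled
                future.set_result(result)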
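
The batch processor is only half of the pattern: each request handler still has to enqueue its input together with a future and then await that future. The sketch below shows that request side end to end using only asyncio. The names fake_model, batch_worker, and predict are hypothetical stand-ins (fake_model replaces the real model call so the example runs without torch or a GPU), and the batch cap of 8 is an assumption.

```python
import asyncio

MAX_BATCH = 8            # assumed batch cap for this demo
BATCH_WINDOW_S = 0.010   # 10 ms collection window


def fake_model(batch: list[float]) -> list[float]:
    """Hypothetical stand-in for the real model call: doubles every input."""
    return [x * 2 for x in batch]


async def batch_worker(queue: asyncio.Queue, batch_sizes: list[int]) -> None:
    """Group queued requests into batches and resolve each request's future."""
    loop = asyncio.get_running_loop()
    while True:
        item, future = await queue.get()          # wait for the first request
        batch, futures = [item], [future]
        deadline = loop.time() + BATCH_WINDOW_S
        while len(batch) < MAX_BATCH:
            remaining = deadline - loop.time()
            if remaining <= 0:
                break
            try:
                item, future = await asyncio.wait_for(queue.get(), timeout=remaining)
            except asyncio.TimeoutError:
                break
            batch.append(item)
            futures.append(future)
        batch_sizes.append(len(batch))            # record how requests were grouped
        for result, future in zip(fake_model(batch), futures):
            if not future.done():
                future.set_result(result)


async def predict(queue: asyncio.Queue, item: float) -> float:
    """What a request handler would do: enqueue the input, then await its result."""
    future = asyncio.get_running_loop().create_future()
    await queue.put((item, future))
    return await future


async def main() -> tuple[list[float], list[int]]:
    queue: asyncio.Queue = asyncio.Queue()
    batch_sizes: list[int] = []
    worker = asyncio.create_task(batch_worker(queue, batch_sizes))
    try:
        # Five concurrent "requests" arrive at roughly the same time
        results = await asyncio.gather(*(predict(queue, float(i)) for i in range(5)))
    finally:
        worker.cancel()
    return list(results), batch_sizes


if __name__ == "__main__":
    print(asyncio.run(main()))
```

In a FastAPI app, the worker task would typically be started once at application startup, and predict would be the body of the endpoint handler.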

Request/Response Schemas

from pydantic import BaseModel, Field
import base64

class PredictRequest(BaseModel):
    image_b64: str           # Base64-encoded image
    # Reject thresholds outside the valid probability range
    confidence_threshold: float = Field(0.5, ge=0.0, le=1.0)

class Detection(BaseModel):
    label: str
    confidence: float
    bbox: list[float]        # [x1, y1, x2, y2]

class PredictResponse(BaseModel):
    detections: list[Detection]
    inference_ms: float
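
The schemas validate types, but the handler still has to decode image_b64 and apply confidence_threshold. A minimal sketch of those two steps using only the standard library follows. The helper names decode_image_b64 and filter_detections and the 5 MiB size cap are assumptions for illustration, and detections are shown as plain dicts rather than Detection models to keep the example dependency-free.

```python
import base64
import binascii

MAX_IMAGE_BYTES = 5 * 1024 * 1024  # assumed 5 MiB cap; not specified by the lab


def decode_image_b64(image_b64: str, max_bytes: int = MAX_IMAGE_BYTES) -> bytes:
    """Decode the image_b64 field, rejecting malformed, empty, or oversized payloads."""
    try:
        # validate=True makes non-alphabet characters an error instead of ignoring them
        raw = base64.b64decode(image_b64, validate=True)
    except (binascii.Error, ValueError) as exc:
        raise ValueError("image_b64 is not valid base64") from exc
    if not raw:
        raise ValueError("image_b64 decoded to an empty payload")
    if len(raw) > max_bytes:
        raise ValueError(f"image is {len(raw)} bytes; limit is {max_bytes}")
    return raw


def filter_detections(detections: list[dict], confidence_threshold: float) -> list[dict]:
    """Keep only detections at or above the requested confidence threshold."""
    return [d for d in detections if d["confidence"] >= confidence_threshold]
```

A handler would usually translate the ValueError into an HTTP 4xx response so that bad input never reaches the model.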

Prometheus Metrics

from fastapi import FastAPI, Response
from prometheus_client import (CONTENT_TYPE_LATEST, Counter, Gauge, Histogram,
                               generate_latest)

app = FastAPI()  # reuse the lab's existing app instance if one is already defined

REQUEST_COUNT = Counter("inference_requests_total", "Total inference requests")
LATENCY = Histogram("inference_latency_seconds", "Inference latency",
                    buckets=[.005, .01, .025, .05, .1, .25, .5, 1])
BATCH_SIZE = Histogram("inference_batch_size", "Batch sizes processed",
                       buckets=[1, 2, 4, 8, 16, 32])

@app.get("/metrics")
async def metrics():
    # Serve all registered metrics in the Prometheus text exposition format
    return Response(generate_latest(), media_type=CONTENT_TYPE_LATEST)
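
Defining the metrics does not record anything by itself; the inference path has to update them. The sketch below shows one way to do that around a batch call, assuming prometheus_client is installed. The wrapper name run_batch and the infer callable are hypothetical, and a private CollectorRegistry is used here only so the example can be re-run without "duplicate metric" errors; the lab's code above uses the default registry.

```python
import time

from prometheus_client import CollectorRegistry, Counter, Histogram, generate_latest

# Private registry for this demo only (assumption; the lab uses the default registry)
registry = CollectorRegistry()

REQUEST_COUNT = Counter("inference_requests_total", "Total inference requests",
                        registry=registry)
LATENCY = Histogram("inference_latency_seconds", "Inference latency",
                    buckets=[.005, .01, .025, .05, .1, .25, .5, 1],
                    registry=registry)
BATCH_SIZE = Histogram("inference_batch_size", "Batch sizes processed",
                       buckets=[1, 2, 4, 8, 16, 32], registry=registry)


def run_batch(batch, infer):
    """Run infer(batch) and record latency, batch size, and request count."""
    start = time.perf_counter()
    results = infer(batch)
    LATENCY.observe(time.perf_counter() - start)   # seconds, matching the metric name
    BATCH_SIZE.observe(len(batch))
    REQUEST_COUNT.inc(len(batch))                  # one count per request in the batch
    return results
```

Calling run_batch([1, 2, 3], some_model) and then generate_latest(registry) would show the counter at 3 and one observation in each histogram.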

Interview Questions

Q: How do you handle a slow model that takes 500ms per request?
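A: Decouple request handling from inference with a queue and a background worker pool. The client submits a job, gets a job ID back immediately, and polls for the result, so slow inference never ties up a request handler and several jobs can run concurrently. For multi-machine deployments, a distributed task queue such as Celery with Redis fills the same role.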
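
The queue-and-poll pattern from this answer can be sketched with asyncio alone. Everything here is illustrative: slow_model is a hypothetical stand-in (50 ms instead of 500 ms), the in-memory jobs dict stands in for a real result store, and submit and poll model what a POST endpoint and a status endpoint would do.

```python
import asyncio
import itertools
import time

jobs: dict[int, dict] = {}          # job_id -> {"status": ..., "result": ...}
_job_ids = itertools.count(1)


def slow_model(x: int) -> int:
    """Hypothetical slow, blocking model call."""
    time.sleep(0.05)
    return x * x


async def worker(queue: asyncio.Queue) -> None:
    """Pull jobs off the queue and run the blocking model in a worker thread."""
    while True:
        job_id, payload = await queue.get()
        jobs[job_id]["status"] = "running"
        try:
            jobs[job_id]["result"] = await asyncio.to_thread(slow_model, payload)
            jobs[job_id]["status"] = "done"
        except Exception as exc:
            jobs[job_id].update(status="failed", error=str(exc))
        finally:
            queue.task_done()


def submit(queue: asyncio.Queue, payload: int) -> int:
    """Enqueue a job and return its ID immediately (what a POST handler would do)."""
    job_id = next(_job_ids)
    jobs[job_id] = {"status": "queued", "result": None}
    queue.put_nowait((job_id, payload))
    return job_id


async def poll(job_id: int, interval: float = 0.01) -> dict:
    """Check the job status until it finishes (what a polling client would do)."""
    while jobs[job_id]["status"] not in ("done", "failed"):
        await asyncio.sleep(interval)
    return jobs[job_id]


async def main() -> list[int]:
    queue: asyncio.Queue = asyncio.Queue()
    workers = [asyncio.create_task(worker(queue)) for _ in range(2)]
    job_ids = [submit(queue, n) for n in (2, 3, 4)]
    finished = [await poll(job_id) for job_id in job_ids]
    for w in workers:
        w.cancel()
    return [job["result"] for job in finished]


if __name__ == "__main__":
    print(asyncio.run(main()))
```

The in-process version loses queued jobs on restart, which is one reason to move to an external broker once durability matters.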

Q: What's the difference between async def and def in FastAPI?
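A: async def handlers run directly on the event loop, which suits I/O-bound work as long as every slow call is awaited. Plain def handlers are run in a thread pool automatically by FastAPI, so a blocking call there does not stall other requests. For CPU-bound inference, use a def handler or offload the work to an executor (for example, a ProcessPoolExecutor) so the event loop is never blocked.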
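
The effect of blocking the event loop can be observed directly. In this sketch, a heartbeat coroutine ticks every 10 ms while a blocking call runs either inline on the loop or in a worker thread via asyncio.to_thread (Python 3.9+). blocking_inference is a hypothetical stand-in that sleeps, so it releases the GIL the way most native inference libraries do; pure-Python CPU work would need a process pool instead.

```python
import asyncio
import time


def blocking_inference(x: int) -> int:
    """Hypothetical blocking model call (100 ms)."""
    time.sleep(0.1)
    return x + 1


async def heartbeat(ticks: list[float], stop: asyncio.Event) -> None:
    """Stands in for other requests that need the event loop to stay responsive."""
    while not stop.is_set():
        ticks.append(time.perf_counter())
        await asyncio.sleep(0.01)


async def run_with_heartbeat(offload: bool) -> tuple[int, int]:
    """Return (result, heartbeat ticks that happened while inference was running)."""
    ticks: list[float] = []
    stop = asyncio.Event()
    beat = asyncio.create_task(heartbeat(ticks, stop))
    await asyncio.sleep(0)                      # let the heartbeat start
    ticks_before = len(ticks)
    if offload:
        # Runs in a worker thread; the loop keeps serving the heartbeat
        result = await asyncio.to_thread(blocking_inference, 41)
    else:
        # Runs on the loop itself; nothing else can execute until it returns
        result = blocking_inference(41)
    ticks_during = len(ticks) - ticks_before
    stop.set()
    await beat
    return result, ticks_during


if __name__ == "__main__":
    print("inline:   ", asyncio.run(run_with_heartbeat(offload=False)))
    print("offloaded:", asyncio.run(run_with_heartbeat(offload=True)))
```

The inline run reports zero heartbeat ticks during inference, while the offloaded run reports several; the exact count depends on timer resolution.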

Q: How do you prevent OOM on the GPU server?
A: Cap concurrent inference calls with an asyncio.Semaphore (for example, asyncio.Semaphore(4)). Also limit input image size and batch size, and add a /health check that monitors GPU memory usage.
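
A minimal sketch of the semaphore cap, assuming a limit of 4. guarded_inference is a hypothetical wrapper, and the asyncio.sleep call stands in for the real model call; the stats dict exists only to show that the number of in-flight calls never exceeds the limit.

```python
import asyncio

MAX_CONCURRENT = 4   # assumed limit; size it to what fits in GPU memory


async def guarded_inference(sem: asyncio.Semaphore, stats: dict, x: int) -> int:
    """Run one inference call while holding a semaphore slot."""
    async with sem:                      # waits here if MAX_CONCURRENT calls are active
        stats["in_flight"] += 1
        stats["peak"] = max(stats["peak"], stats["in_flight"])
        try:
            await asyncio.sleep(0.01)    # stand-in for the model call
            return x * 2
        finally:
            stats["in_flight"] -= 1


async def main() -> tuple[list[int], int]:
    sem = asyncio.Semaphore(MAX_CONCURRENT)
    stats = {"in_flight": 0, "peak": 0}
    # Twelve requests arrive at once, but at most four run at the same time
    results = await asyncio.gather(
        *(guarded_inference(sem, stats, i) for i in range(12))
    )
    return list(results), stats["peak"]


if __name__ == "__main__":
    print(asyncio.run(main()))
```

Requests beyond the limit wait at the semaphore rather than failing, so pairing it with a request timeout keeps the waiting line bounded.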