Lab 7-02: FastAPI Inference Server
Learning Goals
- Build a production-ready REST API for ML model inference
- Implement async request handling and dynamic batching
- Add health checks, Prometheus metrics, and request logging
- Handle concurrent requests without blocking the GPU
Core Concepts
Why FastAPI?
FastAPI is a widely used choice for Python ML serving because:
- Native async support via asyncio
- Automatic OpenAPI/Swagger docs
- Pydantic validation for request/response schemas
- Roughly 2-3× higher throughput than Flask on concurrent, I/O-bound workloads (the exact figure depends on the benchmark and server setup)
Dynamic Batching
GPU utilization is maximized by batching requests together. The key tradeoff:
Latency ↑ (wait for batch to fill)
vs
Throughput ↑ (process more requests per second)
Strategy: Collect requests for a configurable window (e.g., 10ms), then process as one batch.
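import asyncio
import torch

MAX_BATCH = 8        # assumed value; tune to what fits in GPU memory
MAX_WAIT_S = 0.010   # batching window: 10 ms
batch_queue: asyncio.Queue = asyncio.Queue()

async def batch_processor():
    loop = asyncio.get_running_loop()
    while True:
        # Block until the first request arrives, then open the batching window
        # (avoids waking up every 10 ms while the server is idle)
        item, future = await batch_queue.get()
        batch, futures = [item], [future]
        deadline = loop.time() + MAX_WAIT_S
        # Collect more requests until the window closes or the batch is full
        while len(batch) < MAX_BATCH:
            remaining = deadline - loop.time()
            if remaining <= 0:
                break
            try:
                item, future = await asyncio.wait_for(batch_queue.get(), timeout=remaining)
            except asyncio.TimeoutError:
                break
            batch.append(item)
            futures.append(future)
        try:
            # model_inference is assumed to be defined elsewhere in the lab;
            # run it in a worker thread so the event loop stays responsive
            results = await loop.run_in_executor(None, model_inference, torch.stack(batch))
        except Exception as exc:
            # Propagate the failure to every waiting request instead of hanging
            for future in futures:
                future.set_exception(exc)
            continue
        for result, future in zip(results, futures):
            future.set_result(result)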
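The processor only drains the queue; each request handler still has to enqueue its input together with a future and then await that future. The following is a minimal sketch of that hand-off, not the lab's reference solution. It assumes a stand-in `fake_model` that doubles each number (in place of `model_inference`) and plain integers instead of tensors, so it runs without PyTorch. The names `submit`, `batch_worker`, and `demo` are illustrative.

```python
import asyncio

MAX_BATCH = 4        # assumed batch cap for the demo
MAX_WAIT_S = 0.010   # 10 ms batching window

def fake_model(batch):
    """Stand-in for model_inference: doubles every input in the batch."""
    return [x * 2 for x in batch]

async def batch_worker(queue, batch_sizes):
    loop = asyncio.get_running_loop()
    while True:
        # Wait for the first request, then open the batching window
        item, future = await queue.get()
        batch, futures = [item], [future]
        deadline = loop.time() + MAX_WAIT_S
        while len(batch) < MAX_BATCH:
            remaining = deadline - loop.time()
            if remaining <= 0:
                break
            try:
                item, future = await asyncio.wait_for(queue.get(), timeout=remaining)
            except asyncio.TimeoutError:
                break
            batch.append(item)
            futures.append(future)
        batch_sizes.append(len(batch))
        # Hand each result back to the request that is awaiting it
        for result, future in zip(fake_model(batch), futures):
            future.set_result(result)

async def submit(queue, x):
    """What a request handler does: enqueue (input, future), await the future."""
    future = asyncio.get_running_loop().create_future()
    await queue.put((x, future))
    return await future

async def demo():
    queue = asyncio.Queue()
    batch_sizes = []
    worker = asyncio.create_task(batch_worker(queue, batch_sizes))
    # Six concurrent "requests" arrive at once
    results = await asyncio.gather(*(submit(queue, x) for x in range(6)))
    worker.cancel()
    return results, batch_sizes

if __name__ == "__main__":
    print(asyncio.run(demo()))
```

Each caller sees only its own result, in its own order, even though the work was grouped into shared batches. In a FastAPI handler the body of `submit` would sit inside the `async def` endpoint.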
Request/Response Schemas
from pydantic import BaseModel
import base64

class PredictRequest(BaseModel):
    image_b64: str  # Base64-encoded image
    confidence_threshold: float = 0.5

class Detection(BaseModel):
    label: str
    confidence: float
    bbox: list[float]  # [x1, y1, x2, y2]

class PredictResponse(BaseModel):
    detections: list[Detection]
    inference_ms: float
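To show how the two request fields might be consumed, here is a small sketch using only the standard library. `decode_image_b64` and `filter_detections` are hypothetical helpers (not part of the lab code), `MAX_IMAGE_BYTES` is an assumed limit, and detections are plain dicts mirroring the `Detection` fields so that the snippet runs without Pydantic.

```python
import base64
import binascii

MAX_IMAGE_BYTES = 5 * 1024 * 1024  # assumed upper bound on decoded image size

def decode_image_b64(image_b64: str) -> bytes:
    """Decode the image_b64 field, rejecting malformed or oversized payloads."""
    try:
        raw = base64.b64decode(image_b64, validate=True)
    except (binascii.Error, ValueError) as exc:
        raise ValueError("image_b64 is not valid base64") from exc
    if len(raw) > MAX_IMAGE_BYTES:
        raise ValueError("decoded image exceeds MAX_IMAGE_BYTES")
    return raw

def filter_detections(detections: list, confidence_threshold: float = 0.5) -> list:
    """Keep only detections whose confidence meets the threshold."""
    return [d for d in detections if d["confidence"] >= confidence_threshold]

# Example: a tiny fake payload and two made-up detections
payload = base64.b64encode(b"fake image bytes").decode("ascii")
raw = decode_image_b64(payload)
kept = filter_detections(
    [
        {"label": "cat", "confidence": 0.91, "bbox": [10.0, 20.0, 110.0, 220.0]},
        {"label": "dog", "confidence": 0.32, "bbox": [5.0, 5.0, 50.0, 60.0]},
    ],
    confidence_threshold=0.5,
)
```

Rejecting bad input before it reaches the model keeps malformed payloads from consuming GPU time; in the server, the `ValueError` would typically be turned into a 4xx response.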
Prometheus Metrics
from fastapi import FastAPI, Response
from prometheus_client import CONTENT_TYPE_LATEST, Counter, Histogram, Gauge, generate_latest

app = FastAPI()  # in the full server, reuse the existing app instance

REQUEST_COUNT = Counter("inference_requests_total", "Total inference requests")
LATENCY = Histogram("inference_latency_seconds", "Inference latency",
                    buckets=[.005, .01, .025, .05, .1, .25, .5, 1])
BATCH_SIZE = Histogram("inference_batch_size", "Batch sizes processed",
                       buckets=[1, 2, 4, 8, 16, 32])

@app.get("/metrics")
async def metrics():
    # Serve all registered metrics in the Prometheus text exposition format
    return Response(generate_latest(), media_type=CONTENT_TYPE_LATEST)
Interview Questions
Q: How do you handle a slow model that takes 500ms per request?
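A: Move inference off the request path. The handler enqueues the job and returns a job ID right away, a pool of background workers drains the queue, and the client polls for the result. No request handler is blocked for 500ms, and several jobs can be in flight at once. For multi-machine setups, a distributed task queue such as Celery with Redis fills the same role.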
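A minimal in-process sketch of the queue-and-poll pattern from this answer, using only `asyncio`. The names `submit_job`, `poll_job`, `worker`, and the in-memory `jobs` dict are illustrative assumptions, and `slow_model` sleeps briefly as a stand-in for the slow model. A production setup would need persistent job state and error handling, which this sketch omits.

```python
import asyncio
import itertools
import time

jobs = {}                      # job_id -> {"status": ..., "result": ...}
_job_ids = itertools.count(1)  # simple in-memory ID generator

def slow_model(x):
    """Stand-in for a slow, blocking model call."""
    time.sleep(0.05)
    return x * x

async def submit_job(queue, x):
    """POST-style step: enqueue the work and return a job ID immediately."""
    job_id = next(_job_ids)
    jobs[job_id] = {"status": "pending", "result": None}
    await queue.put((job_id, x))
    return job_id

def poll_job(job_id):
    """GET-style step: report the current state without waiting."""
    return jobs[job_id]

async def worker(queue):
    loop = asyncio.get_running_loop()
    while True:
        job_id, x = await queue.get()
        # Run the blocking model call in a thread so the event loop stays free
        result = await loop.run_in_executor(None, slow_model, x)
        jobs[job_id] = {"status": "done", "result": result}
        queue.task_done()

async def demo():
    queue = asyncio.Queue()
    workers = [asyncio.create_task(worker(queue)) for _ in range(2)]
    job_id = await submit_job(queue, 7)
    first_poll = dict(poll_job(job_id))   # polled right after submit
    await queue.join()                    # wait until the queue is drained
    final_poll = dict(poll_job(job_id))
    for w in workers:
        w.cancel()
    return first_poll, final_poll

if __name__ == "__main__":
    print(asyncio.run(demo()))
```

The first poll reports the job as pending because submission returns before any worker has run; once the queue is drained, the same job ID yields the finished result.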
Q: What's the difference between async def and def in FastAPI?
A: async def handlers run directly on the event loop, which suits I/O-bound work as long as every slow call is awaited. Plain def handlers run in a thread pool, which FastAPI manages automatically. For CPU-bound inference, use def or offload the call to a ProcessPoolExecutor so that the event loop is not blocked.
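The effect of blocking the event loop can be observed with a small standard-library experiment. `blocking_inference` is a hypothetical stand-in that sleeps for 200 ms. A heartbeat task counts how often the loop gets to run while the call is in flight, first when the call is made directly inside a coroutine and then when it is offloaded to a thread with `run_in_executor`, which is roughly how a `def` handler is treated.

```python
import asyncio
import time

def blocking_inference():
    """Stand-in for a model call that blocks the calling thread."""
    time.sleep(0.2)
    return "done"

async def heartbeat(ticks):
    # Increments a counter every 10 ms, but only while the event loop is free
    while True:
        await asyncio.sleep(0.01)
        ticks[0] += 1

async def measure(offload: bool) -> int:
    ticks = [0]
    hb = asyncio.create_task(heartbeat(ticks))
    await asyncio.sleep(0)  # let the heartbeat task start
    if offload:
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, blocking_inference)
    else:
        blocking_inference()  # called directly: the event loop is stalled
    count = ticks[0]          # heartbeats observed while the call was running
    hb.cancel()
    return count

async def demo():
    blocked = await measure(offload=False)
    offloaded = await measure(offload=True)
    return blocked, offloaded

if __name__ == "__main__":
    blocked, offloaded = asyncio.run(demo())
    print(f"heartbeats while blocked: {blocked}, while offloaded: {offloaded}")
```

With the direct call the heartbeat never fires, which is what every other request would experience in a real server. Threads help here because `time.sleep` releases the GIL, as most native inference libraries do; pure-Python CPU work would still need a process pool.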
Q: How do you prevent OOM on the GPU server?
A: Cap the number of concurrent inference calls with an asyncio.Semaphore, for example asyncio.Semaphore(4). Also limit input image size and batch size. Add a /health check that monitors GPU memory usage.
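A short sketch of the semaphore cap, assuming a limit of 4 and a 10 ms `asyncio.sleep` as a stand-in for the model call. `guarded_inference` and the `state` dict are illustrative; the dict exists only to record the peak concurrency for the demo.

```python
import asyncio

MAX_CONCURRENT = 4  # assumed limit; size it to what fits in GPU memory

async def guarded_inference(sem, state, x):
    """Run one fake inference while holding a slot in the semaphore."""
    async with sem:
        state["active"] += 1
        state["peak"] = max(state["peak"], state["active"])
        await asyncio.sleep(0.01)  # stand-in for the real model call
        state["active"] -= 1
        return x * 2

async def demo(n_requests=20):
    sem = asyncio.Semaphore(MAX_CONCURRENT)
    state = {"active": 0, "peak": 0}
    results = await asyncio.gather(
        *(guarded_inference(sem, state, x) for x in range(n_requests))
    )
    return results, state["peak"]

if __name__ == "__main__":
    results, peak = asyncio.run(demo())
    print(f"{len(results)} requests served, peak concurrency {peak}")
```

All 20 requests complete, but no more than `MAX_CONCURRENT` ever hold GPU-side resources at the same moment; the rest wait at `async with sem` without blocking the event loop.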