Model deployments fail. A newly promoted model might tank accuracy, spike latency, or start returning garbage predictions on edge cases your test suite never covered. If you don’t have automated rollback, you’re stuck watching dashboards at 2 AM hoping someone notices before customers do.

The fix is a health-check loop that continuously monitors your serving model and automatically swaps it back to the last known good version when things go south. Here’s a minimal health tracker to give you the idea:

import time
from collections import deque

class HealthTracker:
    def __init__(self, window_size: int = 100):
        self.latencies = deque(maxlen=window_size)
        self.errors = deque(maxlen=window_size)

    def record(self, latency_ms: float, is_error: bool):
        self.latencies.append(latency_ms)
        self.errors.append(1 if is_error else 0)

    def error_rate(self) -> float:
        if not self.errors:
            return 0.0
        return sum(self.errors) / len(self.errors)

    def p95_latency(self) -> float:
        if not self.latencies:
            return 0.0
        sorted_lat = sorted(self.latencies)
        idx = int(len(sorted_lat) * 0.95)
        return sorted_lat[min(idx, len(sorted_lat) - 1)]

That’s the core pattern. Track a sliding window of requests, compute aggregate metrics, and trigger rollback when thresholds are breached. Now let’s build the full pipeline.
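The sliding window itself comes free with deque(maxlen=...): once the window is full, each append evicts the oldest sample, so the metrics always reflect the most recent requests. A quick illustration:

```python
from collections import deque

window = deque(maxlen=3)
for latency in [10, 20, 30, 40]:
    window.append(latency)

# The oldest sample (10) was evicted automatically.
print(list(window))  # [20, 30, 40]
```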

Building the Model Registry

You need a registry that knows which model versions exist, which one is active, and which was last known good. A database works for production, but a JSON file gets the job done for smaller setups and local testing.

import json
import os
import shutil
from datetime import datetime, timezone
from pathlib import Path
from dataclasses import dataclass, asdict

REGISTRY_DIR = Path("model_registry")
REGISTRY_FILE = REGISTRY_DIR / "registry.json"
MODELS_DIR = REGISTRY_DIR / "models"

@dataclass
class ModelVersion:
    version: str
    path: str
    registered_at: str
    is_active: bool = False
    is_healthy: bool = True

def _load_registry() -> list[dict]:
    if not REGISTRY_FILE.exists():
        return []
    with open(REGISTRY_FILE, "r") as f:
        return json.load(f)

def _save_registry(entries: list[dict]):
    REGISTRY_DIR.mkdir(parents=True, exist_ok=True)
    with open(REGISTRY_FILE, "w") as f:
        json.dump(entries, f, indent=2)

def register_model(version: str, source_path: str) -> ModelVersion:
    """Copy a model file into the registry and track it."""
    MODELS_DIR.mkdir(parents=True, exist_ok=True)
    dest = MODELS_DIR / f"model_{version}.pkl"
    shutil.copy2(source_path, dest)

    entry = ModelVersion(
        version=version,
        path=str(dest),
        registered_at=datetime.now(timezone.utc).isoformat(),
    )

    entries = _load_registry()
    entries.append(asdict(entry))
    _save_registry(entries)
    return entry

def promote_model(version: str) -> ModelVersion:
    """Set a model version as the active deployment."""
    entries = _load_registry()
    target = None
    for entry in entries:
        if entry["version"] == version:
            entry["is_active"] = True
            target = entry
        else:
            entry["is_active"] = False
    if target is None:
        raise ValueError(f"Version {version} not found in registry")
    _save_registry(entries)
    return ModelVersion(**target)

def get_active_model() -> ModelVersion | None:
    entries = _load_registry()
    for entry in entries:
        if entry["is_active"]:
            return ModelVersion(**entry)
    return None

def get_last_healthy_model() -> ModelVersion | None:
    """Find the most recent model that was marked healthy and is not active."""
    entries = _load_registry()
    healthy = [e for e in entries if e["is_healthy"] and not e["is_active"]]
    if not healthy:
        return None
    healthy.sort(key=lambda e: e["registered_at"], reverse=True)
    return ModelVersion(**healthy[0])

def mark_unhealthy(version: str):
    entries = _load_registry()
    for entry in entries:
        if entry["version"] == version:
            entry["is_healthy"] = False
    _save_registry(entries)

def rollback() -> ModelVersion | None:
    """Mark the current model unhealthy and promote the last healthy version."""
    active = get_active_model()
    if active:
        mark_unhealthy(active.version)

    fallback = get_last_healthy_model()
    if fallback:
        return promote_model(fallback.version)
    return None

The key design choice: rollback() marks the failing model as unhealthy before promoting the fallback. That prevents the system from rolling back to the same broken model if the fallback also gets marked unhealthy later. You always move backward through the chain of healthy versions.
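To make that chain concrete, here’s a toy in-memory walkthrough. The name toy_rollback and the list-ordering shortcut are simplified stand-ins for the registry functions above, not the real code:

```python
# Toy in-memory stand-in for the registry: three versions, newest (v3) active.
versions = [
    {"version": "v1", "healthy": True, "active": False},
    {"version": "v2", "healthy": True, "active": False},
    {"version": "v3", "healthy": True, "active": True},
]

def toy_rollback(entries):
    # Mark the active version unhealthy, then promote the newest remaining
    # healthy one: the system only ever moves backward through the chain.
    active = next(e for e in entries if e["active"])
    active["healthy"] = False
    active["active"] = False
    healthy = [e for e in entries if e["healthy"]]
    if not healthy:
        return None
    healthy[-1]["active"] = True
    return healthy[-1]["version"]

order = [toy_rollback(versions) for _ in range(3)]
print(order)  # ['v2', 'v1', None]
```

Each rollback permanently burns a version, so after the chain is exhausted you get None, which is exactly the "manual intervention required" case.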

Implementing Health Checks

Health checks need to catch three failure modes: the model is too slow, it’s throwing errors, or its predictions have drifted from expected distributions. Here’s a complete health checker that handles all three:

import statistics
from collections import deque
from dataclasses import dataclass, field

@dataclass
class HealthThresholds:
    max_p95_latency_ms: float = 500.0
    max_error_rate: float = 0.05
    max_mean_drift: float = 2.0  # standard deviations from baseline mean

@dataclass
class HealthChecker:
    thresholds: HealthThresholds = field(default_factory=HealthThresholds)
    window_size: int = 200
    latencies: deque = field(default_factory=deque)
    errors: deque = field(default_factory=deque)
    predictions: deque = field(default_factory=deque)
    baseline_mean: float | None = None
    baseline_std: float | None = None

    def __post_init__(self):
        # Rebuild the deques so window_size actually bounds the windows,
        # rather than a hardcoded maxlen that ignores the field.
        self.latencies = deque(self.latencies, maxlen=self.window_size)
        self.errors = deque(self.errors, maxlen=self.window_size)
        self.predictions = deque(self.predictions, maxlen=self.window_size)

    def set_baseline(self, historical_predictions: list[float]):
        """Set baseline stats from the previous model's prediction distribution."""
        self.baseline_mean = statistics.mean(historical_predictions)
        self.baseline_std = statistics.stdev(historical_predictions) if len(historical_predictions) > 1 else 1.0

    def record_prediction(self, latency_ms: float, is_error: bool, prediction_value: float | None = None):
        self.latencies.append(latency_ms)
        self.errors.append(1 if is_error else 0)
        if prediction_value is not None:
            self.predictions.append(prediction_value)

    def p95_latency(self) -> float:
        if not self.latencies:
            return 0.0
        sorted_lat = sorted(self.latencies)
        idx = int(len(sorted_lat) * 0.95)
        return sorted_lat[min(idx, len(sorted_lat) - 1)]

    def error_rate(self) -> float:
        if not self.errors:
            return 0.0
        return sum(self.errors) / len(self.errors)

    def prediction_drift(self) -> float | None:
        """How many standard deviations the current mean is from baseline."""
        if self.baseline_mean is None or self.baseline_std is None:
            return None
        if len(self.predictions) < 10:
            return None
        current_mean = statistics.mean(self.predictions)
        if self.baseline_std == 0:
            return 0.0
        return abs(current_mean - self.baseline_mean) / self.baseline_std

    def is_healthy(self) -> tuple[bool, list[str]]:
        """Return (healthy, list_of_violations)."""
        violations = []
        min_samples = 20

        if len(self.latencies) >= min_samples:
            p95 = self.p95_latency()
            if p95 > self.thresholds.max_p95_latency_ms:
                violations.append(f"p95 latency {p95:.1f}ms exceeds {self.thresholds.max_p95_latency_ms}ms")

        if len(self.errors) >= min_samples:
            rate = self.error_rate()
            if rate > self.thresholds.max_error_rate:
                violations.append(f"error rate {rate:.2%} exceeds {self.thresholds.max_error_rate:.2%}")

        drift = self.prediction_drift()
        if drift is not None and drift > self.thresholds.max_mean_drift:
            violations.append(f"prediction drift {drift:.2f} std devs exceeds {self.thresholds.max_mean_drift}")

        return (len(violations) == 0, violations)

The min_samples threshold prevents false alarms during warmup. You don’t want to trigger a rollback after 3 requests because 2 of them happened to be slow. Wait until you have enough data to make a statistical judgment.
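To see why, run the same index calculation as p95_latency() on a 3-request window:

```python
# Two slow requests in a 3-sample window dominate the percentile.
latencies = [40.0, 900.0, 950.0]
sorted_lat = sorted(latencies)
idx = int(len(sorted_lat) * 0.95)               # int(2.85) == 2
p95 = sorted_lat[min(idx, len(sorted_lat) - 1)]
print(p95)  # 950.0, which would breach a 500 ms threshold after 3 requests
```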

Prediction drift is the sneaky one. A model can return HTTP 200 with low latency but produce completely wrong values. Tracking mean drift against a baseline catches cases where a model silently degrades – scores that cluster around 0.5 instead of following the expected bimodal distribution, or regression outputs that are off by an order of magnitude.
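Here’s the drift math as a sketch with made-up numbers: regression outputs shrink by an order of magnitude while every response is still a fast HTTP 200.

```python
import statistics

# Made-up baseline from the previous model's regression outputs.
baseline = [100.0, 110.0, 95.0, 105.0, 98.0]
base_mean = statistics.mean(baseline)
base_std = statistics.stdev(baseline)

# The new model silently returns values an order of magnitude too small.
current = [10.0, 11.0, 9.5, 10.5, 9.8]
drift = abs(statistics.mean(current) - base_mean) / base_std

print(drift > 2.0)  # True: well past the max_mean_drift threshold
```

Latency and error-rate checks would pass this model without complaint; only the distribution check flags it.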

Building the Rollback API with FastAPI

Now wire everything into a FastAPI app. The lifespan context manager handles model loading at startup and cleanup at shutdown:

import pickle
import time
import logging
from contextlib import asynccontextmanager

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

logger = logging.getLogger("rollback_api")

# Global state
current_model = None
health_checker = HealthChecker(
    thresholds=HealthThresholds(max_p95_latency_ms=500, max_error_rate=0.05, max_mean_drift=2.0)
)

def load_model(path: str):
    with open(path, "rb") as f:
        return pickle.load(f)

def swap_model():
    """Load the currently active model from the registry."""
    global current_model
    active = get_active_model()
    if active is None:
        raise RuntimeError("No active model in registry")
    current_model = load_model(active.path)
    logger.info(f"Loaded model version {active.version} from {active.path}")

@asynccontextmanager
async def lifespan(app: FastAPI):
    swap_model()
    yield
    logger.info("Shutting down, cleaning up model resources")

app = FastAPI(lifespan=lifespan)

class PredictRequest(BaseModel):
    features: list[float]

class PredictResponse(BaseModel):
    prediction: float
    model_version: str
    latency_ms: float

@app.post("/predict", response_model=PredictResponse)
def predict(req: PredictRequest):
    active = get_active_model()
    if active is None or current_model is None:
        raise HTTPException(status_code=503, detail="No model loaded")

    start = time.perf_counter()
    is_error = False
    prediction_value = 0.0

    try:
        prediction_value = float(current_model.predict([req.features])[0])
    except Exception as e:
        is_error = True
        health_checker.record_prediction(
            latency_ms=(time.perf_counter() - start) * 1000,
            is_error=True,
        )
        raise HTTPException(status_code=500, detail=str(e))

    latency_ms = (time.perf_counter() - start) * 1000
    health_checker.record_prediction(latency_ms, is_error, prediction_value)

    return PredictResponse(
        prediction=prediction_value,
        model_version=active.version,
        latency_ms=round(latency_ms, 2),
    )

@app.get("/health")
def health():
    active = get_active_model()
    healthy, violations = health_checker.is_healthy()
    return {
        "healthy": healthy,
        "violations": violations,
        "active_model": active.version if active else None,
        "p95_latency_ms": round(health_checker.p95_latency(), 2),
        "error_rate": round(health_checker.error_rate(), 4),
        "prediction_drift": health_checker.prediction_drift(),
    }

@app.post("/rollback")
def trigger_rollback():
    rolled_back = rollback()
    if rolled_back is None:
        raise HTTPException(status_code=404, detail="No healthy model to roll back to")
    swap_model()
    # Reset the health checker for the new model
    health_checker.latencies.clear()
    health_checker.errors.clear()
    health_checker.predictions.clear()
    return {"rolled_back_to": rolled_back.version, "path": rolled_back.path}

Notice that swap_model() is called both at startup and after rollback. The /rollback endpoint also clears the health checker’s sliding windows so the new model gets a clean evaluation period. Without that reset, stale bad metrics from the previous model would immediately trigger another rollback.
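A tiny illustration of the stale-window problem, using just the error window with made-up numbers:

```python
from collections import deque

# Stale errors recorded against the model we just rolled away from.
errors = deque([1] * 50, maxlen=200)
print(sum(errors) / len(errors))  # 1.0, instantly over any error threshold

errors.clear()                    # give the new model a clean window
errors.extend([0] * 20)           # first 20 requests on the fallback model
print(sum(errors) / len(errors))  # 0.0
```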

Automated Rollback with a Monitoring Loop

Manual rollback endpoints are nice, but the real value is a background loop that watches the health checker and triggers rollback without human intervention. Use a background thread so it works alongside the async FastAPI server:

import threading

def monitoring_loop(check_interval: float = 10.0):
    """Background thread that checks model health and triggers rollback."""
    logger.info(f"Monitoring loop started, checking every {check_interval}s")

    while True:
        time.sleep(check_interval)

        healthy, violations = health_checker.is_healthy()
        if not healthy:
            logger.warning(f"Model health check failed: {violations}")
            active = get_active_model()
            if active:
                logger.warning(f"Triggering automatic rollback from version {active.version}")

            rolled_back = rollback()
            if rolled_back:
                try:
                    swap_model()
                    health_checker.latencies.clear()
                    health_checker.errors.clear()
                    health_checker.predictions.clear()
                    logger.info(f"Automatic rollback complete. Now serving {rolled_back.version}")
                except Exception as e:
                    logger.error(f"Failed to load rollback model: {e}")
            else:
                logger.error("No healthy model available for rollback. Manual intervention required.")

@asynccontextmanager
async def lifespan(app: FastAPI):
    swap_model()
    monitor_thread = threading.Thread(target=monitoring_loop, kwargs={"check_interval": 10.0}, daemon=True)
    monitor_thread.start()
    yield
    logger.info("Shutting down")

app = FastAPI(lifespan=lifespan)

Setting daemon=True on the thread means it dies when the main process exits – no orphaned threads. The check_interval is 10 seconds here, but tune it for your traffic volume. High-traffic services can check every 2-3 seconds. Low-traffic services might need 30-60 seconds to accumulate enough samples in the sliding window.

One thing to watch: the monitoring loop and the /rollback endpoint can race. If someone hits the manual rollback endpoint while the loop is mid-check, you might double-rollback. For production use, wrap the rollback logic in a threading.Lock:

rollback_lock = threading.Lock()

def safe_rollback() -> ModelVersion | None:
    with rollback_lock:
        rolled_back = rollback()
        if rolled_back:
            swap_model()
            health_checker.latencies.clear()
            health_checker.errors.clear()
            health_checker.predictions.clear()
        return rolled_back

Call safe_rollback() from both the endpoint and the monitoring loop instead of duplicating the logic.
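A related refinement, sketched here with a hypothetical try_rollback_nonblocking() helper: rather than queue behind the lock, the monitoring loop can skip its attempt entirely if a rollback is already in flight, using a non-blocking acquire.

```python
import threading

rollback_lock = threading.Lock()

def try_rollback_nonblocking() -> bool:
    """Skip, rather than wait, if another rollback is already in flight."""
    if not rollback_lock.acquire(blocking=False):
        return False  # someone else is mid-rollback; don't stack a second one
    try:
        # rollback() + swap_model() + window reset would go here
        return True
    finally:
        rollback_lock.release()

print(try_rollback_nonblocking())  # True: the lock was free
```

This avoids the double-rollback entirely: the second caller returns False instead of performing a redundant swap the moment the lock frees up.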

Common Errors and Fixes

RuntimeError: No active model in registry on startup

You started the server before promoting any model. Register and promote at least one version first:

register_model("v1", "path/to/trained_model.pkl")
promote_model("v1")

Then start the FastAPI app. The lifespan handler calls swap_model() which reads the active model from the registry.

Rollback keeps triggering in a loop

This happens when you only have one model version registered. The current model gets marked unhealthy, get_last_healthy_model() finds nothing, and the unhealthy model stays active and keeps failing its checks, so every cycle attempts (and logs) another doomed rollback. Always keep at least two versions in the registry. A better safeguard is to add a cooldown period after rollback:

last_rollback_time = 0.0
ROLLBACK_COOLDOWN = 120.0  # seconds

def monitoring_loop(check_interval: float = 10.0):
    global last_rollback_time
    while True:
        time.sleep(check_interval)
        healthy, violations = health_checker.is_healthy()
        if not healthy:
            elapsed = time.time() - last_rollback_time
            if elapsed < ROLLBACK_COOLDOWN:
                logger.warning(f"Skipping rollback, cooldown has {ROLLBACK_COOLDOWN - elapsed:.0f}s remaining")
                continue
            last_rollback_time = time.time()
            safe_rollback()

Prediction drift triggers false positives after retraining

When you retrain on new data, the prediction distribution legitimately changes. You need to update the baseline. After promoting a new model and confirming it’s performing well, capture its prediction stats as the new baseline:

# After collecting enough predictions from the new model
recent_predictions = list(health_checker.predictions)
health_checker.set_baseline(recent_predictions)

Or better yet, compute the baseline offline from your validation set and set it during model registration.
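One way to sketch the offline approach, with a hypothetical baseline_stats() helper that is not part of the registry code above:

```python
import statistics

def baseline_stats(validation_predictions: list[float]) -> dict:
    """Compute drift-baseline stats once, offline, from validation-set predictions."""
    return {
        "mean": statistics.mean(validation_predictions),
        "std": statistics.stdev(validation_predictions)
        if len(validation_predictions) > 1
        else 1.0,
    }

# Hypothetically stored alongside the registry entry at register time, then
# used to seed baseline_mean / baseline_std on the checker when promoting.
stats = baseline_stats([0.12, 0.88, 0.15, 0.91, 0.20])
print(round(stats["mean"], 3))  # 0.452
```

Computing the baseline from the validation set decouples it from whatever traffic happened to arrive during the new model’s first hour in production.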