Model deployments fail. A newly promoted model might tank accuracy, spike latency, or start returning garbage predictions on edge cases your test suite never covered. If you don’t have automated rollback, you’re stuck watching dashboards at 2 AM hoping someone notices before customers do.
The fix is a health-check loop that continuously monitors your serving model and automatically swaps it back to the last known good version when things go south. Here’s a minimal health check to give you the idea:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import time
from collections import deque


class HealthTracker:
    """Sliding-window tracker for per-request latency and error observations."""

    def __init__(self, window_size: int = 100):
        # Bounded deques: old observations fall off automatically.
        self.latencies = deque(maxlen=window_size)
        self.errors = deque(maxlen=window_size)

    def record(self, latency_ms: float, is_error: bool):
        """Append one request's latency and error flag to the window."""
        self.latencies.append(latency_ms)
        self.errors.append(1 if is_error else 0)

    def error_rate(self) -> float:
        """Fraction of windowed requests that errored; 0.0 on an empty window."""
        window = self.errors
        return sum(window) / len(window) if window else 0.0

    def p95_latency(self) -> float:
        """Nearest-rank 95th-percentile latency over the window; 0.0 when empty."""
        if not self.latencies:
            return 0.0
        ordered = sorted(self.latencies)
        rank = min(int(len(ordered) * 0.95), len(ordered) - 1)
        return ordered[rank]
That’s the core pattern. Track a sliding window of requests, compute aggregate metrics, and trigger rollback when thresholds are breached. Now let’s build the full pipeline.
Building the Model Registry#
You need a registry that knows which model versions exist, which one is active, and which was last known good. A database works for production, but a JSON file gets the job done for smaller setups and local testing.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
| import json
import os
import shutil
from datetime import datetime, timezone
from pathlib import Path
from dataclasses import dataclass, asdict
# On-disk layout of the file-based registry:
REGISTRY_DIR = Path("model_registry")           # root directory for all registry data
REGISTRY_FILE = REGISTRY_DIR / "registry.json"  # JSON list of ModelVersion entries
MODELS_DIR = REGISTRY_DIR / "models"            # copied model artifacts live here
@dataclass
class ModelVersion:
    """One registered model artifact plus its deployment status flags."""

    version: str        # caller-chosen version label, e.g. "v1"
    path: str           # path of the copied artifact inside MODELS_DIR
    registered_at: str  # UTC ISO-8601 timestamp (lexicographically sortable)
    is_active: bool = False  # True for the version currently being served
    is_healthy: bool = True  # cleared by mark_unhealthy() during rollback
def _load_registry() -> list[dict]:
    """Read every registry entry; an absent file means an empty registry."""
    if REGISTRY_FILE.exists():
        with open(REGISTRY_FILE, "r") as handle:
            return json.load(handle)
    return []
def _save_registry(entries: list[dict]):
    """Persist the full entry list, creating the registry directory on demand."""
    REGISTRY_DIR.mkdir(parents=True, exist_ok=True)
    REGISTRY_FILE.write_text(json.dumps(entries, indent=2))
def register_model(version: str, source_path: str) -> ModelVersion:
    """Copy a model file into the registry and track it."""
    MODELS_DIR.mkdir(parents=True, exist_ok=True)
    destination = MODELS_DIR / f"model_{version}.pkl"
    # copy2 preserves metadata (mtime etc.) of the source artifact.
    shutil.copy2(source_path, destination)
    record = ModelVersion(
        version=version,
        path=str(destination),
        registered_at=datetime.now(timezone.utc).isoformat(),
    )
    existing = _load_registry()
    existing.append(asdict(record))
    _save_registry(existing)
    return record
def promote_model(version: str) -> ModelVersion:
    """Set a model version as the active deployment.

    Raises ValueError when the version is not in the registry; exactly one
    entry ends up active afterwards (all others are demoted).
    """
    entries = _load_registry()
    target = None
    for candidate in entries:
        is_match = candidate["version"] == version
        candidate["is_active"] = is_match
        if is_match:
            target = candidate
    if target is None:
        raise ValueError(f"Version {version} not found in registry")
    _save_registry(entries)
    return ModelVersion(**target)
def get_active_model() -> ModelVersion | None:
    """Return the currently active version, or None when nothing is promoted."""
    active = next((e for e in _load_registry() if e["is_active"]), None)
    return ModelVersion(**active) if active is not None else None
def get_last_healthy_model() -> ModelVersion | None:
    """Find the most recent model that was marked healthy and is not active."""
    candidates = [
        e for e in _load_registry() if e["is_healthy"] and not e["is_active"]
    ]
    if not candidates:
        return None
    # ISO-8601 timestamps sort lexicographically, so max() finds the newest.
    newest = max(candidates, key=lambda e: e["registered_at"])
    return ModelVersion(**newest)
def mark_unhealthy(version: str):
    """Flag every entry matching *version* as unhealthy (no-op if absent)."""
    entries = _load_registry()
    matching = (e for e in entries if e["version"] == version)
    for entry in matching:
        entry["is_healthy"] = False
    _save_registry(entries)
def rollback() -> ModelVersion | None:
    """Mark the current model unhealthy and promote the last healthy version.

    Returns the promoted fallback, or None when no healthy fallback exists.
    """
    current = get_active_model()
    if current is not None:
        # Demote first so the failing model can never be chosen as fallback.
        mark_unhealthy(current.version)
    candidate = get_last_healthy_model()
    return promote_model(candidate.version) if candidate else None
|
The key design choice: rollback() marks the failing model as unhealthy before promoting the fallback. That prevents the system from rolling back to the same broken model if the fallback also gets marked unhealthy later. You always move backward through the chain of healthy versions.
Implementing Health Checks#
Health checks need to catch three failure modes: the model is too slow, it’s throwing errors, or its predictions have drifted from expected distributions. Here’s a complete health checker that handles all three:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
| import statistics
from collections import deque
from dataclasses import dataclass, field
@dataclass
class HealthThresholds:
max_p95_latency_ms: float = 500.0
max_error_rate: float = 0.05
max_mean_drift: float = 2.0 # standard deviations from baseline mean
@dataclass
class HealthChecker:
thresholds: HealthThresholds = field(default_factory=HealthThresholds)
window_size: int = 200
latencies: deque = field(default_factory=lambda: deque(maxlen=200))
errors: deque = field(default_factory=lambda: deque(maxlen=200))
predictions: deque = field(default_factory=lambda: deque(maxlen=200))
baseline_mean: float | None = None
baseline_std: float | None = None
def set_baseline(self, historical_predictions: list[float]):
"""Set baseline stats from the previous model's prediction distribution."""
self.baseline_mean = statistics.mean(historical_predictions)
self.baseline_std = statistics.stdev(historical_predictions) if len(historical_predictions) > 1 else 1.0
def record_prediction(self, latency_ms: float, is_error: bool, prediction_value: float | None = None):
self.latencies.append(latency_ms)
self.errors.append(1 if is_error else 0)
if prediction_value is not None:
self.predictions.append(prediction_value)
def p95_latency(self) -> float:
if not self.latencies:
return 0.0
sorted_lat = sorted(self.latencies)
idx = int(len(sorted_lat) * 0.95)
return sorted_lat[min(idx, len(sorted_lat) - 1)]
def error_rate(self) -> float:
if not self.errors:
return 0.0
return sum(self.errors) / len(self.errors)
def prediction_drift(self) -> float | None:
"""How many standard deviations the current mean is from baseline."""
if self.baseline_mean is None or self.baseline_std is None:
return None
if len(self.predictions) < 10:
return None
current_mean = statistics.mean(self.predictions)
if self.baseline_std == 0:
return 0.0
return abs(current_mean - self.baseline_mean) / self.baseline_std
def is_healthy(self) -> tuple[bool, list[str]]:
"""Return (healthy, list_of_violations)."""
violations = []
min_samples = 20
if len(self.latencies) >= min_samples:
p95 = self.p95_latency()
if p95 > self.thresholds.max_p95_latency_ms:
violations.append(f"p95 latency {p95:.1f}ms exceeds {self.thresholds.max_p95_latency_ms}ms")
if len(self.errors) >= min_samples:
rate = self.error_rate()
if rate > self.thresholds.max_error_rate:
violations.append(f"error rate {rate:.2%} exceeds {self.thresholds.max_error_rate:.2%}")
drift = self.prediction_drift()
if drift is not None and drift > self.thresholds.max_mean_drift:
violations.append(f"prediction drift {drift:.2f} std devs exceeds {self.thresholds.max_mean_drift}")
return (len(violations) == 0, violations)
|
The min_samples threshold prevents false alarms during warmup. You don’t want to trigger a rollback after 3 requests because 2 of them happened to be slow. Wait until you have enough data to make a statistical judgment.
Prediction drift is the sneaky one. A model can return HTTP 200 with low latency but produce completely wrong values. Tracking mean drift against a baseline catches cases where a model silently degrades – scores that cluster around 0.5 instead of following the expected bimodal distribution, or regression outputs that are off by an order of magnitude.
Building the Rollback API with FastAPI#
Now wire everything into a FastAPI app. The lifespan context manager handles model loading at startup and cleanup at shutdown:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
| import pickle
import time
import logging
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
logger = logging.getLogger("rollback_api")
# Global state shared by the request handlers and the monitoring loop.
current_model = None  # deserialized model currently serving /predict; set by swap_model()
health_checker = HealthChecker(
    thresholds=HealthThresholds(max_p95_latency_ms=500, max_error_rate=0.05, max_mean_drift=2.0)
)
def load_model(path: str):
    """Deserialize a pickled model artifact from *path*.

    NOTE(review): pickle is only safe for artifacts you produced yourself;
    never load model files from untrusted sources.
    """
    with open(path, "rb") as handle:
        return pickle.load(handle)
def swap_model():
    """Load whichever registry version is active into the process global.

    Raises RuntimeError when the registry has no active model.
    """
    global current_model
    deployed = get_active_model()
    if deployed is None:
        raise RuntimeError("No active model in registry")
    loaded = load_model(deployed.path)
    current_model = loaded
    logger.info(f"Loaded model version {deployed.version} from {deployed.path}")
@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup: pull the active model from the registry before serving traffic.
    swap_model()
    yield
    # Shutdown: nothing to release yet; log for operational visibility.
    logger.info("Shutting down, cleaning up model resources")
app = FastAPI(lifespan=lifespan)
class PredictRequest(BaseModel):
    # Flat numeric feature vector; shape expectations are the model's concern.
    features: list[float]


class PredictResponse(BaseModel):
    prediction: float   # scalar model output
    model_version: str  # registry version that produced the prediction
    latency_ms: float   # wall-clock inference time, rounded to 2 decimals
@app.post("/predict", response_model=PredictResponse)
def predict(req: PredictRequest):
    """Run inference with the active model, feeding the health checker either way."""
    active = get_active_model()
    if active is None or current_model is None:
        raise HTTPException(status_code=503, detail="No model loaded")
    start = time.perf_counter()
    try:
        prediction_value = float(current_model.predict([req.features])[0])
    except Exception as e:
        # Count the failure toward the error-rate window before surfacing it.
        health_checker.record_prediction(
            latency_ms=(time.perf_counter() - start) * 1000,
            is_error=True,
        )
        raise HTTPException(status_code=500, detail=str(e))
    latency_ms = (time.perf_counter() - start) * 1000
    health_checker.record_prediction(latency_ms, False, prediction_value)
    return PredictResponse(
        prediction=prediction_value,
        model_version=active.version,
        latency_ms=round(latency_ms, 2),
    )
@app.get("/health")
def health():
    """Expose the current metrics and threshold verdict for dashboards/probes."""
    active = get_active_model()
    healthy, violations = health_checker.is_healthy()
    report = {
        "healthy": healthy,
        "violations": violations,
        "active_model": active.version if active else None,
        "p95_latency_ms": round(health_checker.p95_latency(), 2),
        "error_rate": round(health_checker.error_rate(), 4),
        "prediction_drift": health_checker.prediction_drift(),
    }
    return report
@app.post("/rollback")
def trigger_rollback():
    """Manually demote the active model and serve the last healthy version."""
    fallback = rollback()
    if fallback is None:
        raise HTTPException(status_code=404, detail="No healthy model to roll back to")
    swap_model()
    # Reset the sliding windows so the fallback model is judged on fresh data.
    for window in (health_checker.latencies, health_checker.errors, health_checker.predictions):
        window.clear()
    return {"rolled_back_to": fallback.version, "path": fallback.path}
|
Notice that swap_model() is called both at startup and after rollback. The /rollback endpoint also clears the health checker’s sliding windows so the new model gets a clean evaluation period. Without that reset, stale bad metrics from the previous model would immediately trigger another rollback.
Automated Rollback with a Monitoring Loop#
Manual rollback endpoints are nice, but the real value is a background loop that watches the health checker and triggers rollback without human intervention. Use a background thread so it works alongside the async FastAPI server:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
| import threading
def monitoring_loop(check_interval: float = 10.0):
    """Background thread that checks model health and triggers rollback."""
    logger.info(f"Monitoring loop started, checking every {check_interval}s")
    while True:
        time.sleep(check_interval)
        healthy, violations = health_checker.is_healthy()
        if healthy:
            continue
        logger.warning(f"Model health check failed: {violations}")
        active = get_active_model()
        if not active:
            continue
        logger.warning(f"Triggering automatic rollback from version {active.version}")
        rolled_back = rollback()
        if not rolled_back:
            logger.error("No healthy model available for rollback. Manual intervention required.")
            continue
        try:
            swap_model()
            # Fresh windows so the fallback isn't judged on the old model's metrics.
            health_checker.latencies.clear()
            health_checker.errors.clear()
            health_checker.predictions.clear()
            logger.info(f"Automatic rollback complete. Now serving {rolled_back.version}")
        except Exception as e:
            logger.error(f"Failed to load rollback model: {e}")
@asynccontextmanager
async def lifespan(app: FastAPI):
    # Load the active model, then start the watchdog before accepting traffic.
    swap_model()
    # daemon=True: the thread is terminated with the process, no join required.
    monitor_thread = threading.Thread(target=monitoring_loop, kwargs={"check_interval": 10.0}, daemon=True)
    monitor_thread.start()
    yield
    logger.info("Shutting down")
app = FastAPI(lifespan=lifespan)
|
Setting daemon=True on the thread means it is terminated automatically when the main process exits – no orphaned threads, though it also gets no chance to clean up, which is fine for a read-only monitoring loop. The check_interval is 10 seconds here, but tune it for your traffic volume. High-traffic services can check every 2-3 seconds. Low-traffic services might need 30-60 seconds to accumulate enough samples in the sliding window.
One thing to watch: the monitoring loop and the /rollback endpoint can race. If someone hits the manual rollback endpoint while the loop is mid-check, you might double-rollback. For production use, wrap the rollback logic in a threading.Lock:
1
2
3
4
5
6
7
8
9
10
11
rollback_lock = threading.Lock()


def safe_rollback() -> ModelVersion | None:
    """Thread-safe rollback shared by the endpoint and the monitoring loop.

    Serializes the registry swap and metric reset so concurrent callers
    can't double-rollback. Returns the promoted fallback, or None.
    """
    with rollback_lock:
        demoted_to = rollback()
        if demoted_to:
            swap_model()
            # Clean evaluation window for the newly promoted model.
            for window in (health_checker.latencies, health_checker.errors, health_checker.predictions):
                window.clear()
        return demoted_to
|
Call safe_rollback() from both the endpoint and the monitoring loop instead of duplicating the logic.
Common Errors and Fixes#
RuntimeError: No active model in registry on startup
You started the server before promoting any model. Register and promote at least one version first:
1
2
# Seed the registry with an initial version and make it the active deployment.
register_model("v1", "path/to/trained_model.pkl")
promote_model("v1")
|
Then start the FastAPI app. The lifespan handler calls swap_model() which reads the active model from the registry.
Rollback keeps triggering in a loop
This happens when you only have one model version registered. The current model gets marked unhealthy, get_last_healthy_model() finds nothing, and the still-active (now unhealthy) model keeps failing its health checks, so every cycle retries the rollback. Always keep at least two versions in the registry. A better safeguard is to add a cooldown period after rollback:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
last_rollback_time = 0.0
ROLLBACK_COOLDOWN = 120.0  # seconds


def monitoring_loop(check_interval: float = 10.0):
    """Health watchdog with a cooldown so rollbacks can't cascade in a loop."""
    global last_rollback_time
    while True:
        time.sleep(check_interval)
        healthy, violations = health_checker.is_healthy()
        if healthy:
            continue
        elapsed = time.time() - last_rollback_time
        if elapsed < ROLLBACK_COOLDOWN:
            logger.warning(f"Skipping rollback, cooldown has {ROLLBACK_COOLDOWN - elapsed:.0f}s remaining")
            continue
        last_rollback_time = time.time()
        safe_rollback()
|
Prediction drift triggers false positives after retraining
When you retrain on new data, the prediction distribution legitimately changes. You need to update the baseline. After promoting a new model and confirming it’s performing well, capture its prediction stats as the new baseline:
1
2
3
# After collecting enough predictions from the new model, snapshot its
# current prediction window and adopt it as the drift baseline.
recent_predictions = list(health_checker.predictions)
health_checker.set_baseline(recent_predictions)
|
Or better yet, compute the baseline offline from your validation set and set it during model registration.