The Core Pattern#
When an AI model goes sideways in production, you need three things happening in sequence: detect the anomaly, stop the bleeding, and restore the last known good state. Most teams only have alerting. The ones who sleep well at night have automated rollback.
Here’s the architecture in a nutshell: a health monitor watches prediction distributions and error rates, a circuit breaker trips when thresholds are crossed, a rollback controller swaps in the previous model version, and an incident logger captures everything for the post-mortem.
```shell
pip install numpy scipy
```
```python
import time
import hashlib
import json
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Optional

import numpy as np
from scipy import stats


@dataclass
class ModelHealthMonitor:
    """Tracks prediction distributions and error rates for a deployed model."""
    model_name: str
    model_version: str
    window_size: int = 500
    error_rate_threshold: float = 0.05
    drift_p_value_threshold: float = 0.01
    predictions: list = field(default_factory=list)
    baseline_predictions: list = field(default_factory=list)
    errors: list = field(default_factory=list)
    _baseline_set: bool = False

    def set_baseline(self, baseline_preds: list[float]):
        """Set the baseline prediction distribution from validation or a previous stable period."""
        self.baseline_predictions = list(baseline_preds)
        self._baseline_set = True

    def record_prediction(self, value: float, is_error: bool = False):
        self.predictions.append(value)
        self.errors.append(1 if is_error else 0)
        # Keep a rolling window
        if len(self.predictions) > self.window_size:
            self.predictions = self.predictions[-self.window_size:]
            self.errors = self.errors[-self.window_size:]

    def error_rate(self) -> float:
        if not self.errors:
            return 0.0
        return sum(self.errors) / len(self.errors)

    def detect_distribution_shift(self) -> dict:
        """Run a KS test comparing recent predictions against baseline."""
        if not self._baseline_set or len(self.predictions) < 50:
            return {"shifted": False, "reason": "insufficient data"}
        stat, p_value = stats.ks_2samp(self.baseline_predictions, self.predictions)
        shifted = p_value < self.drift_p_value_threshold
        return {
            "shifted": shifted,
            "ks_statistic": round(stat, 4),
            "p_value": round(p_value, 6),
        }

    def health_check(self) -> dict:
        current_error_rate = self.error_rate()
        drift = self.detect_distribution_shift()
        healthy = (
            current_error_rate < self.error_rate_threshold
            and not drift["shifted"]
        )
        return {
            "model": self.model_name,
            "version": self.model_version,
            "healthy": healthy,
            "error_rate": round(current_error_rate, 4),
            "distribution_drift": drift,
            "sample_count": len(self.predictions),
            "checked_at": datetime.now(timezone.utc).isoformat(),
        }
```
Use set_baseline() with predictions from your validation set or from a known stable production window. The monitor uses a two-sample Kolmogorov-Smirnov test to catch distribution shifts – if your fraud model suddenly starts scoring everything at 0.9 instead of a normal spread, this catches it before your ops team notices the Slack complaints.
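To see what the drift check is doing under the hood, here's a standalone sketch of the same two-sample KS test on synthetic data (the score distributions are made up for illustration, not taken from a real model):

```python
import numpy as np
from scipy import stats

rng = np.random.default_rng(7)
baseline = rng.normal(0.2, 0.1, size=500)  # healthy spread of fraud scores
stuck = rng.normal(0.9, 0.02, size=500)    # model suddenly scoring everything ~0.9

# Same call the monitor makes: large statistic + tiny p-value means the
# two samples almost certainly come from different distributions.
stat, p_value = stats.ks_2samp(baseline, stuck)
print(p_value < 0.01)  # True: the shift is flagged
```

With distributions this far apart the KS statistic approaches 1.0 and the p-value collapses to effectively zero, which is exactly the regime the `drift_p_value_threshold` of 0.01 is meant to catch.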
Circuit Breaker and Kill Switch#
A circuit breaker prevents a broken model from serving predictions. This is the “stop the bleeding” step. Once the health monitor flags a problem, the circuit breaker opens and all inference requests get routed to a fallback – either a previous model version or a simple heuristic.
```python
from enum import Enum


class CircuitState(Enum):
    CLOSED = "closed"        # Normal operation
    OPEN = "open"            # Model disabled, using fallback
    HALF_OPEN = "half_open"  # Testing if model recovered


@dataclass
class CircuitBreaker:
    """Circuit breaker for AI model inference."""
    failure_threshold: int = 10
    recovery_timeout_sec: float = 300.0  # 5 minutes before retry
    half_open_max_calls: int = 5
    state: CircuitState = CircuitState.CLOSED
    failure_count: int = 0
    last_failure_time: float = 0.0
    half_open_successes: int = 0

    def record_success(self):
        if self.state == CircuitState.HALF_OPEN:
            self.half_open_successes += 1
            if self.half_open_successes >= self.half_open_max_calls:
                self.state = CircuitState.CLOSED
                self.failure_count = 0
                self.half_open_successes = 0
        else:
            self.failure_count = max(0, self.failure_count - 1)

    def record_failure(self):
        self.failure_count += 1
        self.last_failure_time = time.time()
        if self.failure_count >= self.failure_threshold:
            self.state = CircuitState.OPEN

    def allow_request(self) -> bool:
        if self.state == CircuitState.CLOSED:
            return True
        if self.state == CircuitState.OPEN:
            elapsed = time.time() - self.last_failure_time
            if elapsed >= self.recovery_timeout_sec:
                self.state = CircuitState.HALF_OPEN
                self.half_open_successes = 0
                return True
            return False
        # HALF_OPEN: let limited traffic through
        return True

    def force_open(self):
        """Manual kill switch -- immediately stop all model traffic."""
        self.state = CircuitState.OPEN
        self.last_failure_time = time.time()

    def force_close(self):
        """Manual override to resume traffic."""
        self.state = CircuitState.CLOSED
        self.failure_count = 0
```
The force_open() method is your kill switch. Wire this to an API endpoint or a CLI command so anyone on-call can slam it in an emergency. The recovery timeout gives your team breathing room – the circuit stays open for 5 minutes before tentatively allowing traffic through in half-open mode.
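The open-to-half-open timing is the part teams most often get wrong, so here's a condensed, self-contained stand-in (not the full `CircuitBreaker` class, just the timing logic) you can run to see the behavior:

```python
import time
from dataclasses import dataclass
from typing import Optional

# Condensed stand-in for CircuitBreaker -- just enough state to show the
# open -> half-open timing, with a short timeout so the demo runs quickly.
@dataclass
class MiniBreaker:
    recovery_timeout_sec: float = 0.1
    open_since: Optional[float] = None

    def force_open(self):
        self.open_since = time.time()

    def allow_request(self) -> bool:
        if self.open_since is None:
            return True  # closed: normal traffic
        if time.time() - self.open_since >= self.recovery_timeout_sec:
            return True  # half-open: probe traffic allowed
        return False     # open: all traffic blocked

b = MiniBreaker()
b.force_open()
print(b.allow_request())  # False immediately after the kill switch
time.sleep(0.15)
print(b.allow_request())  # True once the recovery timeout elapses
```

In production you'd keep the real 5-minute timeout; the 100 ms here is only so the sketch demonstrates the transition without a five-minute wait.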
Automated Rollback Controller#
Once the circuit breaker trips, you want the system to swap in the last known good model version without human intervention. Here’s a rollback controller that manages model version state and executes the swap.
```python
@dataclass
class ModelVersion:
    version: str
    artifact_path: str
    deployed_at: str
    validation_score: float
    is_active: bool = False


class RollbackController:
    """Manages model versions and performs automated rollback on failure."""

    def __init__(self, model_name: str):
        self.model_name = model_name
        self.versions: list[ModelVersion] = []
        self.incident_log: list[dict] = []

    def register_version(self, version: ModelVersion):
        self.versions.append(version)

    def get_active_version(self) -> Optional[ModelVersion]:
        for v in self.versions:
            if v.is_active:
                return v
        return None

    def get_previous_version(self) -> Optional[ModelVersion]:
        active = self.get_active_version()
        if not active:
            return None
        candidates = [
            v for v in self.versions
            if v.version != active.version and not v.is_active
        ]
        if not candidates:
            return None
        # Pick the most recently deployed non-active version
        return max(candidates, key=lambda v: v.deployed_at)

    def rollback(self, reason: str, health_snapshot: dict) -> dict:
        """Roll back to the previous model version and log the incident."""
        active = self.get_active_version()
        target = self.get_previous_version()
        if not active or not target:
            return {"success": False, "error": "No rollback target available"}
        # Deactivate current, activate previous
        active.is_active = False
        target.is_active = True
        incident = {
            "incident_id": hashlib.sha256(
                f"{self.model_name}-{datetime.now(timezone.utc).isoformat()}".encode()
            ).hexdigest()[:12],
            "model": self.model_name,
            "rolled_back_from": active.version,
            "rolled_back_to": target.version,
            "reason": reason,
            "health_snapshot": health_snapshot,
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "status": "automatic_rollback",
        }
        self.incident_log.append(incident)
        return {"success": True, "incident": incident}

    def get_incidents(self) -> list[dict]:
        return list(self.incident_log)
```
In a real deployment, the rollback() method would also call your model serving infrastructure – updating a Kubernetes deployment, swapping a SageMaker endpoint, or changing the model path in your load balancer config. The key point is that the version tracking and incident logging happen atomically with the swap.
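One way to keep that atomicity honest is to attempt the infra call first and only record the version swap if it lands. A minimal sketch with the infra call stubbed out -- `swap_fn` and the S3 path are assumptions standing in for your real serving API, not a real library call:

```python
from typing import Callable

# Hypothetical glue around the infra call: swap_fn stands in for whatever
# touches your serving layer (a kubectl patch, an endpoint update).
def safe_swap(artifact_path: str, swap_fn: Callable[[str], bool]) -> bool:
    try:
        return bool(swap_fn(artifact_path))
    except Exception:
        return False  # a failed infra call must not mark the rollback as done

# Stubbed usage:
print(safe_swap("s3://models/v2", lambda path: True))  # True

def broken_swap(path):
    raise RuntimeError("serving layer unreachable")

print(safe_swap("s3://models/v2", broken_swap))  # False
```

If `safe_swap` returns False, leave the `is_active` flags untouched and log a failed-rollback incident instead, so the version state never claims a swap that didn't happen.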
Wiring It All Together#
Here’s how these pieces connect in a monitoring loop:
```python
def run_health_loop(
    monitor: ModelHealthMonitor,
    breaker: CircuitBreaker,
    controller: RollbackController,
    check_interval_sec: float = 30.0,
):
    """Periodic health check that triggers rollback when the model degrades."""
    while True:
        health = monitor.health_check()
        if not health["healthy"]:
            breaker.record_failure()
            print(f"[WARN] Unhealthy check: error_rate={health['error_rate']}, "
                  f"drift={health['distribution_drift']}")
            if not breaker.allow_request():
                result = controller.rollback(
                    reason=f"error_rate={health['error_rate']}, "
                           f"drift={health['distribution_drift']}",
                    health_snapshot=health,
                )
                if result["success"]:
                    incident = result["incident"]
                    print(f"[ROLLBACK] {incident['rolled_back_from']} -> "
                          f"{incident['rolled_back_to']} "
                          f"(incident: {incident['incident_id']})")
                    # Reset the breaker for the new model version
                    breaker.force_close()
                    monitor.predictions.clear()
                    monitor.errors.clear()
        else:
            breaker.record_success()
        time.sleep(check_interval_sec)
```
Set check_interval_sec based on your traffic volume. High-traffic services can check every 10-15 seconds. Lower-traffic ones might check every minute or two. The important thing is that the loop is autonomous – no pager required for the initial response.
Incident Logging for Post-Mortems#
Every automated rollback should produce a structured incident record. The RollbackController above captures this, but you should also persist these to durable storage. Ship them to your logging pipeline alongside your normal application logs.
A good incident record answers five questions: what model failed, when it failed, what the symptoms were (error rate, drift metrics), what action was taken, and what version replaced it. The health_snapshot field captures the raw numbers so you can reconstruct the failure during the post-mortem without guessing.
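For concreteness, here's what one record looks like with the five questions mapped onto its fields -- every value below is made up for illustration:

```python
import json

# Illustrative incident record; all values are hypothetical.
incident = {
    "incident_id": "a1b2c3d4e5f6",
    "model": "fraud-scorer",                    # what model failed
    "timestamp": "2024-06-01T12:00:00+00:00",   # when it failed
    "health_snapshot": {                        # what the symptoms were
        "error_rate": 0.12,
        "distribution_drift": {"shifted": True, "p_value": 0.0001},
    },
    "status": "automatic_rollback",             # what action was taken
    "rolled_back_from": "v3",
    "rolled_back_to": "v2",                     # what version replaced it
    "reason": "error_rate=0.12, drift=shifted",
}
print(json.dumps(incident, indent=2))
```

Serializing to JSON before shipping keeps the record queryable in whatever log aggregator you already use.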
Common Errors and Fixes#
Health monitor reports drift on every check after deployment
Your baseline is stale. When you deploy a new model version, you need to call set_baseline() with predictions from that version’s validation set. A common mistake is keeping the baseline from the old model, which guarantees the KS test flags a shift.
Circuit breaker flaps between open and closed
Your failure_threshold is too low or your recovery_timeout_sec is too short. Start with a threshold of 10 failures and a 5-minute timeout. Increase the timeout if the underlying issue takes longer to resolve – for example, if rollback involves a container restart that takes 2 minutes.
Rollback target has worse performance than the broken model
This happens when teams register every deployed version without tracking validation scores. Filter your rollback candidates by validation_score – only roll back to versions that passed your quality bar. Add a check in get_previous_version() to skip versions below a minimum score.
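The filter is a one-line change to the candidate selection. A standalone sketch using plain dicts in place of the `ModelVersion` dataclass, with `MIN_SCORE` as an assumed quality bar:

```python
MIN_SCORE = 0.85  # assumption: your minimum acceptable validation score

versions = [
    {"version": "v1", "deployed_at": "2024-01-10", "is_active": False, "validation_score": 0.91},
    {"version": "v2", "deployed_at": "2024-02-10", "is_active": False, "validation_score": 0.78},
    {"version": "v3", "deployed_at": "2024-03-10", "is_active": True,  "validation_score": 0.88},
]

# Same selection as get_previous_version(), plus the quality-bar filter
candidates = [
    v for v in versions
    if not v["is_active"] and v["validation_score"] >= MIN_SCORE
]
# v2 is more recent but below the bar, so v1 is the rollback target
target = max(candidates, key=lambda v: v["deployed_at"])
print(target["version"])  # v1
```

Without the score filter, this example would roll back to v2 and trade one degraded model for another.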
KS test gives false positives on low traffic
The two-sample KS test needs a reasonable sample size to be reliable. If you have fewer than 50 predictions in your window, skip the drift check entirely and rely on error rate alone. The monitor code above already handles this, but make sure your window_size matches your traffic patterns.
Kill switch endpoint is behind the same load balancer as the broken model
Put your kill switch API on a separate service or at minimum a separate health check path. If the model serving infrastructure is down, you need the kill switch to still be reachable. A simple Redis key or feature flag service works well for this.
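The cheapest version of an out-of-band switch is a flag file on a volume your serving processes can read. A minimal sketch, assuming a shared filesystem path (the path below is hypothetical):

```python
from pathlib import Path

# Hypothetical flag-file kill switch: anything that can touch a file on the
# shared volume can stop model traffic, independent of the serving stack.
KILL_SWITCH = Path("/tmp/fraud_scorer_kill_switch")  # assumed shared path

def kill_switch_engaged() -> bool:
    return KILL_SWITCH.exists()

def engage():
    KILL_SWITCH.touch()

def disengage():
    KILL_SWITCH.unlink(missing_ok=True)

disengage()
print(kill_switch_engaged())  # False
engage()
print(kill_switch_engaged())  # True
disengage()
```

Your inference path then checks `kill_switch_engaged()` (or the equivalent Redis key / feature flag) before calling the model, so the switch works even when the model service itself is wedged.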