Pushing a new model straight to 100% of traffic is asking for trouble. A gradual rollout lets you send a small slice of requests to the new model, compare metrics against the current one, and promote or roll back automatically. Here’s the full pipeline: Redis-backed feature flags, percentage-based routing in FastAPI, metric collection, and automated promotion logic.
```python
import redis

# decode_responses=True so hash values come back as str, not bytes
r = redis.Redis(host="localhost", port=6379, db=0, decode_responses=True)

# Set up a rollout: model_v2 gets 10% of traffic
r.hset("rollout:prediction_model", mapping={
    "champion": "model_v1",
    "challenger": "model_v2",
    "percentage": 10,
    "status": "active",
})
```
That’s your feature flag. A hash in Redis that stores which models are in play and what percentage of traffic hits the challenger. Everything else builds on top of this.
The Feature Flag Store
You need a thin wrapper around Redis so the rest of your code doesn’t care about serialization details. This class reads and writes rollout configs, and exposes a simple route_request method that picks champion or challenger based on the configured percentage.
```python
import random
from dataclasses import dataclass

import redis


@dataclass
class RolloutConfig:
    champion: str
    challenger: str
    percentage: int  # 0-100, percentage routed to challenger
    status: str  # "active", "paused", "promoted", "rolled_back"


class FeatureFlagStore:
    def __init__(self, redis_url: str = "redis://localhost:6379/0"):
        self.r = redis.Redis.from_url(redis_url, decode_responses=True)

    def get_rollout(self, name: str) -> RolloutConfig | None:
        data = self.r.hgetall(f"rollout:{name}")
        if not data:
            return None
        return RolloutConfig(
            champion=data["champion"],
            challenger=data["challenger"],
            percentage=int(data["percentage"]),
            status=data["status"],
        )

    def set_rollout(self, name: str, config: RolloutConfig) -> None:
        self.r.hset(f"rollout:{name}", mapping={
            "champion": config.champion,
            "challenger": config.challenger,
            "percentage": config.percentage,
            "status": config.status,
        })

    def route_request(self, name: str) -> str:
        """Return the model name to use for this request."""
        config = self.get_rollout(name)
        if config is None or config.status != "active":
            # No active rollout: use champion or fall back
            return config.champion if config else "model_v1"
        if random.randint(1, 100) <= config.percentage:
            return config.challenger
        return config.champion

    def update_percentage(self, name: str, new_pct: int) -> None:
        self.r.hset(f"rollout:{name}", "percentage", min(max(new_pct, 0), 100))
```
route_request rolls a random number between 1 and 100. If it falls within the percentage, the request goes to the challenger. Simple, stateless, fast.
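To see the split in action without standing up Redis, here is a minimal sketch of just the routing math (the `route` helper and model names are illustrative, not part of the store above):

```python
import random

def route(percentage: int, champion: str, challenger: str) -> str:
    # Same coin flip as route_request, minus the Redis lookup
    return challenger if random.randint(1, 100) <= percentage else champion

random.seed(42)
counts = {"model_v1": 0, "model_v2": 0}
for _ in range(10_000):
    counts[route(10, "model_v1", "model_v2")] += 1

# counts["model_v2"] / 10_000 lands close to 0.10
```

Over ten thousand draws the challenger share converges tightly on the configured percentage; over twenty draws it can wander, which matters later when we talk about low-traffic variance.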
The FastAPI Serving Layer
The serving layer loads both models, routes each request through the feature flag store, and records which model handled it. We use FastAPI’s lifespan context manager to initialize Redis and models on startup.
```python
import time
from contextlib import asynccontextmanager

from fastapi import FastAPI
from pydantic import BaseModel

MODELS: dict[str, object] = {}
flag_store: FeatureFlagStore | None = None


def load_model(model_name: str):
    """Replace this with your actual model loading logic."""
    # Simulated model: in production, load from MLflow, S3, etc.
    return lambda features: {"prediction": sum(features) / len(features), "model": model_name}


@asynccontextmanager
async def lifespan(app: FastAPI):
    global flag_store
    flag_store = FeatureFlagStore("redis://localhost:6379/0")
    # Load both champion and challenger
    rollout = flag_store.get_rollout("prediction_model")
    if rollout:
        MODELS[rollout.champion] = load_model(rollout.champion)
        MODELS[rollout.challenger] = load_model(rollout.challenger)
    else:
        MODELS["model_v1"] = load_model("model_v1")
    yield
    MODELS.clear()


app = FastAPI(lifespan=lifespan)


class PredictionRequest(BaseModel):
    features: list[float]


class PredictionResponse(BaseModel):
    prediction: float
    model_used: str
    latency_ms: float


@app.post("/predict", response_model=PredictionResponse)
async def predict(req: PredictionRequest):
    model_name = flag_store.route_request("prediction_model")
    model_fn = MODELS.get(model_name)
    if model_fn is None:
        # Fallback if the routed model isn't loaded
        model_fn = MODELS.get("model_v1", lambda f: {"prediction": 0.0, "model": "fallback"})
    start = time.perf_counter()
    result = model_fn(req.features)
    latency_ms = (time.perf_counter() - start) * 1000
    # Record metric for later comparison (record_metric is defined in the next section)
    record_metric(model_name, latency_ms, result["prediction"])
    return PredictionResponse(
        prediction=result["prediction"],
        model_used=model_name,
        latency_ms=round(latency_ms, 2),
    )
```
Every prediction logs which model served it and how long it took. That data feeds the comparison pipeline.
Metric Collection and Comparison
Store per-model metrics in Redis lists (a sorted set keyed by timestamp also works). For a production system you'd use Prometheus or a time-series database, but Redis is enough to show the pattern.
```python
import json
import statistics
import time

import redis


def record_metric(model_name: str, latency_ms: float, prediction: float) -> None:
    """Push a metric entry into a Redis list for the model."""
    r = redis.Redis.from_url("redis://localhost:6379/0", decode_responses=True)
    entry = json.dumps({
        "latency_ms": latency_ms,
        "prediction": prediction,
        "timestamp": time.time(),
    })
    r.lpush(f"metrics:{model_name}", entry)
    r.ltrim(f"metrics:{model_name}", 0, 9999)  # Keep last 10k entries


def get_metrics(model_name: str, last_n: int = 1000) -> list[dict]:
    r = redis.Redis.from_url("redis://localhost:6379/0", decode_responses=True)
    raw = r.lrange(f"metrics:{model_name}", 0, last_n - 1)
    return [json.loads(entry) for entry in raw]


def compare_models(champion: str, challenger: str) -> dict:
    """Compare latency distributions between two models."""
    champ_metrics = get_metrics(champion)
    chall_metrics = get_metrics(challenger)
    if len(champ_metrics) < 50 or len(chall_metrics) < 50:
        return {"decision": "wait", "reason": "not enough data"}
    champ_latencies = [m["latency_ms"] for m in champ_metrics]
    chall_latencies = [m["latency_ms"] for m in chall_metrics]
    champ_p50 = statistics.median(champ_latencies)
    chall_p50 = statistics.median(chall_latencies)
    # Clamp the index so a 99th-percentile lookup can't run past the end
    champ_p99 = sorted(champ_latencies)[min(int(len(champ_latencies) * 0.99), len(champ_latencies) - 1)]
    chall_p99 = sorted(chall_latencies)[min(int(len(chall_latencies) * 0.99), len(chall_latencies) - 1)]
    # Challenger must not regress p50 latency by more than 20%
    latency_ok = chall_p50 <= champ_p50 * 1.2
    # Challenger must not blow up tail latency
    tail_ok = chall_p99 <= champ_p99 * 1.5
    if latency_ok and tail_ok:
        return {
            "decision": "promote",
            "champ_p50": round(champ_p50, 2),
            "chall_p50": round(chall_p50, 2),
            "champ_p99": round(champ_p99, 2),
            "chall_p99": round(chall_p99, 2),
        }
    return {
        "decision": "rollback",
        "reason": f"latency_ok={latency_ok}, tail_ok={tail_ok}",
        "champ_p50": round(champ_p50, 2),
        "chall_p50": round(chall_p50, 2),
    }
```
The compare_models function checks two things: median latency regression and tail latency regression. You can add accuracy, error rate, or any business metric the same way. The key is that the decision is automatic and based on real traffic data.
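If you'd rather not repeat the sorted-index expression for each percentile, it factors out into a small helper. A sketch (the `percentile` function is mine, using the nearest-rank method with a clamp for small samples):

```python
import statistics

def percentile(values: list[float], pct: float) -> float:
    """Nearest-rank percentile; clamps the index so tiny samples can't overflow."""
    ordered = sorted(values)
    idx = min(int(len(ordered) * pct), len(ordered) - 1)
    return ordered[idx]

latencies = [11.0, 11.5, 12.0, 12.2, 12.5, 13.0, 13.5, 14.0, 15.0, 90.0]
p50 = statistics.median(latencies)   # 12.75
p99 = percentile(latencies, 0.99)    # picks up the 90.0 outlier
```

Nearest-rank is coarse but cheap; if you need interpolated percentiles, `statistics.quantiles` can replace it.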
Wire the comparison into a promotion pipeline. This runs on a schedule (cron, Celery beat, or a simple loop) and steps the percentage up if things look good, or rolls back if they don’t.
```python
def run_promotion_step(rollout_name: str = "prediction_model") -> str:
    store = FeatureFlagStore("redis://localhost:6379/0")
    config = store.get_rollout(rollout_name)
    if config is None or config.status != "active":
        return "no active rollout"
    result = compare_models(config.champion, config.challenger)
    if result["decision"] == "wait":
        return "waiting for more data"
    if result["decision"] == "rollback":
        config.status = "rolled_back"
        config.percentage = 0
        store.set_rollout(rollout_name, config)
        return f"rolled back: {result.get('reason', 'metrics failed')}"
    # Promote: step up the percentage
    steps = [10, 25, 50, 75, 100]
    current = config.percentage
    next_pct = 100
    for step in steps:
        if step > current:
            next_pct = step
            break
    if next_pct >= 100:
        # Full promotion: challenger becomes the new champion
        config.champion = config.challenger
        config.challenger = ""
        config.percentage = 0
        config.status = "promoted"
        store.set_rollout(rollout_name, config)
        return "fully promoted"
    store.update_percentage(rollout_name, next_pct)
    return f"stepped up to {next_pct}%"
```
The rollout follows a staircase: 10% -> 25% -> 50% -> 75% -> 100%. At each step, metrics are compared. If the challenger regresses on latency, the whole rollout snaps back to 0%. If it reaches 100% cleanly, the challenger becomes the new champion.
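The staircase logic is easiest to verify in isolation. Pulled out of run_promotion_step as a standalone helper (the `next_step` name is mine):

```python
STEPS = [10, 25, 50, 75, 100]

def next_step(current: int) -> int:
    # First step strictly above the current percentage; 100 means full promotion
    for step in STEPS:
        if step > current:
            return step
    return 100

# 10 -> 25, 50 -> 75, 75 -> 100; anything at or past 100 stays at 100
```

Because the loop looks for the first step strictly greater than the current value, a rollout started at an off-list percentage like 30% still lands cleanly on the next rung (50%).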
You can trigger this from a management endpoint too:
```python
@app.post("/rollout/step")
async def step_rollout():
    result = run_promotion_step()
    return {"result": result}


@app.get("/rollout/status")
async def rollout_status():
    config = flag_store.get_rollout("prediction_model")
    if config is None:
        return {"status": "no rollout configured"}
    return {
        "champion": config.champion,
        "challenger": config.challenger,
        "percentage": config.percentage,
        "status": config.status,
    }
```
Hit /rollout/step from a cron job every 5-10 minutes. Check /rollout/status to see where you are.
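A crontab entry for that, assuming the service listens on localhost:8000:

```shell
# Run the promotion step every 5 minutes
*/5 * * * * curl -s -X POST http://localhost:8000/rollout/step
```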
Common Errors and Fixes
Redis connection refused on startup. Make sure Redis is running before you start FastAPI. The lifespan function will crash if it can’t reach Redis. Run redis-cli ping to verify.
```shell
redis-cli ping
# Should return: PONG
```
Challenger model not loaded after updating rollout config. The lifespan function only loads models at startup. If you add a new challenger while the server is running, you need to either restart the server or add a hot-reload endpoint that loads the new model into the MODELS dict.
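The core of a hot-reload endpoint is just "load whatever rollout names are missing from the registry." A sketch as a plain function you could wrap in a route (the `hot_reload` name and `loader` parameter are illustrative):

```python
def hot_reload(models: dict, champion: str, challenger: str, loader) -> list[str]:
    """Load any rollout model that isn't in the registry yet."""
    loaded = []
    for name in (champion, challenger):
        if name and name not in models:
            models[name] = loader(name)
            loaded.append(name)
    return loaded

models = {"model_v1": object()}
newly_loaded = hot_reload(models, "model_v1", "model_v3", loader=lambda n: object())
# only model_v3 was missing, so only it gets loaded
```

Skipping names already in the dict makes the endpoint safe to call repeatedly, e.g. from the same cron job that steps the rollout.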
Metrics comparison always returns “wait”. You need at least 50 data points per model before comparison kicks in. If your challenger is at 10% and traffic is low, it takes a while to accumulate. Lower the threshold in compare_models for testing, but keep it at 50+ in production to avoid noisy decisions.
Random routing is uneven at low traffic. With 10% rollout and only 20 requests, you might see 0 or 5 go to the challenger. That’s normal variance with small samples. The percentages converge as request volume grows. Don’t panic if early numbers look off.
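You can sanity-check that variance yourself: challenger hits out of n requests at probability p follow a binomial distribution, so the spread at small n is large relative to the mean.

```python
import math

n, p = 20, 0.10
expected = n * p                   # 2 challenger requests on average
std = math.sqrt(n * p * (1 - p))   # ~1.34, so 0 or 5 hits is unremarkable
```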
Rollout stuck at “rolled_back” after a fix. Once you deploy a fixed challenger model, you need to manually reset the rollout status back to “active” and set the percentage to 10 again. The pipeline won’t auto-restart a rolled-back experiment.
```shell
redis-cli hset rollout:prediction_model status active percentage 10 challenger model_v3
```