Your ML model serves predictions one at a time, and you’re wasting 80% of your GPU capacity. Here’s how to fix it with prediction caching and request batching.
The Core Solution#
Cache identical predictions in Redis and batch incoming requests to maximize GPU utilization. This cuts response times for repeated queries from 200ms to under 10ms, and increases throughput by 3-5x by filling your GPU’s parallel processing capacity.
Here’s a complete FastAPI service with Redis caching and dynamic batching:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
| import asyncio
import hashlib
import json
from typing import List, Dict, Any
from dataclasses import dataclass
from datetime import timedelta
import redis.asyncio as redis
import torch
from fastapi import FastAPI, BackgroundTasks
from pydantic import BaseModel
from contextlib import asynccontextmanager
# Redis client for prediction cache
redis_client = redis.Redis(host='localhost', port=6379, decode_responses=True)
# Model singleton (lazy loaded)
model = None
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class PredictionRequest(BaseModel):
text: str
model_version: str = "v1"
class PredictionResponse(BaseModel):
prediction: float
cached: bool
batch_size: int = 1
@dataclass
class BatchItem:
text: str
cache_key: str
future: asyncio.Future
# Batching queue and config
batch_queue: List[BatchItem] = []
batch_lock = asyncio.Lock()
MAX_BATCH_SIZE = 32
MAX_WAIT_MS = 50 # Maximum wait time before forcing a batch
def get_cache_key(text: str, model_version: str) -> str:
"""Generate deterministic cache key from input."""
content = f"{model_version}:{text}"
return f"pred:{hashlib.sha256(content.encode()).hexdigest()[:16]}"
async def get_cached_prediction(cache_key: str) -> float | None:
"""Fetch prediction from Redis cache."""
cached = await redis_client.get(cache_key)
if cached:
return float(cached)
return None
async def cache_prediction(cache_key: str, prediction: float, ttl_seconds: int = 3600):
"""Store prediction in Redis with TTL."""
await redis_client.setex(cache_key, ttl_seconds, str(prediction))
def load_model():
"""Load model once (singleton pattern)."""
global model
if model is None:
# Replace with your actual model
model = torch.nn.Sequential(
torch.nn.Linear(768, 256),
torch.nn.ReLU(),
torch.nn.Linear(256, 1),
torch.nn.Sigmoid()
).to(device)
model.eval()
return model
def batch_predict(texts: List[str]) -> List[float]:
"""Run batched inference on GPU."""
model = load_model()
# Simulate text encoding (replace with your actual encoder)
embeddings = torch.randn(len(texts), 768).to(device)
with torch.no_grad():
predictions = model(embeddings).squeeze().cpu().tolist()
# Handle single prediction case
if isinstance(predictions, float):
predictions = [predictions]
return predictions
async def process_batch():
"""Process accumulated requests as a batch."""
async with batch_lock:
if not batch_queue:
return
# Extract batch items
items = batch_queue.copy()
batch_queue.clear()
texts = [item.text for item in items]
# Run batched inference
predictions = batch_predict(texts)
# Cache results and resolve futures
for item, prediction in zip(items, predictions):
# Cache in background
asyncio.create_task(cache_prediction(item.cache_key, prediction))
# Return result
item.future.set_result(prediction)
async def schedule_batch_processing():
"""Background task that processes batches periodically."""
while True:
await asyncio.sleep(MAX_WAIT_MS / 1000)
if batch_queue:
await process_batch()
@asynccontextmanager
async def lifespan(app):
"""Start background batch processor on startup."""
task = asyncio.create_task(schedule_batch_processing())
yield
task.cancel()
app = FastAPI(lifespan=lifespan)
@app.post("/predict", response_model=PredictionResponse)
async def predict(request: PredictionRequest):
"""Handle prediction with caching and batching."""
cache_key = get_cache_key(request.text, request.model_version)
# Check cache first
cached_pred = await get_cached_prediction(cache_key)
if cached_pred is not None:
return PredictionResponse(
prediction=cached_pred,
cached=True
)
# Add to batch queue
future = asyncio.Future()
batch_item = BatchItem(text=request.text, cache_key=cache_key, future=future)
async with batch_lock:
batch_queue.append(batch_item)
current_batch_size = len(batch_queue)
# Force immediate processing if batch is full
if current_batch_size >= MAX_BATCH_SIZE:
asyncio.create_task(process_batch())
# Wait for batch processing
prediction = await future
return PredictionResponse(
prediction=prediction,
cached=False,
batch_size=current_batch_size
)
@app.get("/metrics")
async def metrics():
"""Get cache hit rate and queue stats."""
# Get Redis stats (requires Redis INFO command)
info = await redis_client.info('stats')
async with batch_lock:
queue_size = len(batch_queue)
return {
"queue_size": queue_size,
"max_batch_size": MAX_BATCH_SIZE,
"max_wait_ms": MAX_WAIT_MS
}
|
This setup handles both caching and batching transparently. Repeated requests return in under 10ms from Redis, while new requests get batched automatically.
Dynamic Batching Configuration#
The sweet spot for batch size and wait time depends on your traffic pattern. High-traffic services should use larger batches (64-128) with shorter waits (20-30ms). Low-traffic APIs need smaller batches (8-16) to avoid adding latency.
Here’s how to tune batching parameters based on measured latency:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
| import time
from collections import deque
from dataclasses import dataclass
from typing import Deque
@dataclass
class BatchMetrics:
"""Track batching performance metrics."""
batch_sizes: Deque[int]
wait_times: Deque[float]
inference_times: Deque[float]
def __init__(self, maxlen: int = 100):
self.batch_sizes = deque(maxlen=maxlen)
self.wait_times = deque(maxlen=maxlen)
self.inference_times = deque(maxlen=maxlen)
def record(self, batch_size: int, wait_time: float, inference_time: float):
self.batch_sizes.append(batch_size)
self.wait_times.append(wait_time)
self.inference_times.append(inference_time)
def avg_batch_size(self) -> float:
return sum(self.batch_sizes) / len(self.batch_sizes) if self.batch_sizes else 0
def avg_wait_time(self) -> float:
return sum(self.wait_times) / len(self.wait_times) if self.wait_times else 0
def avg_inference_time(self) -> float:
return sum(self.inference_times) / len(self.inference_times) if self.inference_times else 0
# Global metrics tracker
metrics_tracker = BatchMetrics()
async def process_batch_with_metrics():
"""Process batch and record performance metrics."""
async with batch_lock:
if not batch_queue:
return
items = batch_queue.copy()
batch_queue.clear()
batch_size = len(items)
wait_start = time.perf_counter()
texts = [item.text for item in items]
inference_start = time.perf_counter()
predictions = batch_predict(texts)
inference_time = time.perf_counter() - inference_start
# Record metrics
metrics_tracker.record(
batch_size=batch_size,
wait_time=inference_start - wait_start,
inference_time=inference_time
)
# Cache and resolve futures
for item, prediction in zip(items, predictions):
asyncio.create_task(cache_prediction(item.cache_key, prediction))
item.future.set_result(prediction)
@app.get("/batch-stats")
async def batch_stats():
"""Return batching performance statistics."""
return {
"avg_batch_size": round(metrics_tracker.avg_batch_size(), 2),
"avg_wait_time_ms": round(metrics_tracker.avg_wait_time() * 1000, 2),
"avg_inference_time_ms": round(metrics_tracker.avg_inference_time() * 1000, 2),
"total_batches": len(metrics_tracker.batch_sizes)
}
def auto_tune_batch_params():
"""Automatically adjust batch parameters based on metrics."""
global MAX_BATCH_SIZE, MAX_WAIT_MS
avg_batch = metrics_tracker.avg_batch_size()
avg_wait = metrics_tracker.avg_wait_time() * 1000 # Convert to ms
# If batches are consistently full, increase max size
if avg_batch > MAX_BATCH_SIZE * 0.9:
MAX_BATCH_SIZE = min(MAX_BATCH_SIZE + 8, 128)
# If batches are small and wait time is high, reduce wait time
if avg_batch < MAX_BATCH_SIZE * 0.3 and avg_wait > MAX_WAIT_MS * 0.8:
MAX_WAIT_MS = max(MAX_WAIT_MS - 10, 20)
# If batches are small and wait time is low, increase wait time
if avg_batch < MAX_BATCH_SIZE * 0.5 and avg_wait < MAX_WAIT_MS * 0.5:
MAX_WAIT_MS = min(MAX_WAIT_MS + 10, 100)
|
Run /batch-stats every few minutes and adjust your config. You want high batch sizes (80%+ of max) without adding too much wait latency (under 50ms).
Measuring Cache Hit Rate#
Cache effectiveness comes down to one metric: hit rate. You need at least 40% hits to justify the Redis overhead. Here’s how to track it properly:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
| from datetime import datetime
class CacheMetrics:
"""Track cache performance."""
def __init__(self):
self.hits = 0
self.misses = 0
self.last_reset = datetime.now()
def record_hit(self):
self.hits += 1
def record_miss(self):
self.misses += 1
def hit_rate(self) -> float:
total = self.hits + self.misses
return self.hits / total if total > 0 else 0.0
def reset(self):
self.hits = 0
self.misses = 0
self.last_reset = datetime.now()
cache_metrics = CacheMetrics()
@app.post("/predict", response_model=PredictionResponse)
async def predict_with_metrics(request: PredictionRequest):
"""Prediction endpoint with cache metrics."""
cache_key = get_cache_key(request.text, request.model_version)
cached_pred = await get_cached_prediction(cache_key)
if cached_pred is not None:
cache_metrics.record_hit()
return PredictionResponse(
prediction=cached_pred,
cached=True
)
cache_metrics.record_miss()
# Batching logic continues as before...
future = asyncio.Future()
batch_item = BatchItem(text=request.text, cache_key=cache_key, future=future)
async with batch_lock:
batch_queue.append(batch_item)
current_batch_size = len(batch_queue)
if current_batch_size >= MAX_BATCH_SIZE:
asyncio.create_task(process_batch_with_metrics())
prediction = await future
return PredictionResponse(
prediction=prediction,
cached=False,
batch_size=current_batch_size
)
@app.get("/cache-stats")
async def cache_stats():
"""Return cache performance statistics."""
return {
"hit_rate": round(cache_metrics.hit_rate() * 100, 2),
"hits": cache_metrics.hits,
"misses": cache_metrics.misses,
"total_requests": cache_metrics.hits + cache_metrics.misses,
"uptime_seconds": (datetime.now() - cache_metrics.last_reset).total_seconds()
}
|
If your hit rate is below 20%, your cache TTL is too short or your traffic is too diverse. Increase TTL from 1 hour to 24 hours for stable models. If hit rate stays low, you’re better off removing caching entirely and focusing on batching.
Common Errors and Fixes#
Redis connection failures: Add retry logic with exponential backoff. Don’t let cache failures kill your service:
1
2
3
4
5
6
7
8
9
10
| async def get_cached_prediction_safe(cache_key: str) -> float | None:
"""Fetch from cache with error handling."""
try:
cached = await redis_client.get(cache_key)
if cached:
return float(cached)
except redis.RedisError as e:
# Log error but don't fail the request
print(f"Cache read error: {e}")
return None
|
Batches timing out: If your model is slow, increase MAX_WAIT_MS to build larger batches. Small batches (under 8 items) waste GPU capacity. Better to wait 100ms and batch 32 items than process 4 items every 20ms.
Memory leaks in batch queue: If your service crashes under high load, you’re not clearing the queue on errors. Wrap process_batch() in a try/except and always clear the queue:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
| async def process_batch_safe():
"""Process batch with error recovery."""
async with batch_lock:
if not batch_queue:
return
items = batch_queue.copy()
batch_queue.clear() # Clear immediately to prevent leaks
try:
texts = [item.text for item in items]
predictions = batch_predict(texts)
for item, prediction in zip(items, predictions):
asyncio.create_task(cache_prediction(item.cache_key, prediction))
item.future.set_result(prediction)
except Exception as e:
# Fail all futures in the batch
for item in items:
item.future.set_exception(e)
|
Cache stampede on cold start: When your service restarts, every request misses the cache and hits the model. Use request coalescing to deduplicate identical in-flight requests:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
| # Track in-flight predictions to avoid duplicate work
in_flight: Dict[str, asyncio.Future] = {}
@app.post("/predict", response_model=PredictionResponse)
async def predict_with_coalescing(request: PredictionRequest):
"""Prediction with request coalescing."""
cache_key = get_cache_key(request.text, request.model_version)
# Check cache
cached_pred = await get_cached_prediction(cache_key)
if cached_pred is not None:
cache_metrics.record_hit()
return PredictionResponse(prediction=cached_pred, cached=True)
cache_metrics.record_miss()
# Check if identical request is already in-flight
if cache_key in in_flight:
prediction = await in_flight[cache_key]
return PredictionResponse(prediction=prediction, cached=False)
# Create future for this request
future = asyncio.Future()
in_flight[cache_key] = future
batch_item = BatchItem(text=request.text, cache_key=cache_key, future=future)
async with batch_lock:
batch_queue.append(batch_item)
if len(batch_queue) >= MAX_BATCH_SIZE:
asyncio.create_task(process_batch_with_metrics())
try:
prediction = await future
return PredictionResponse(prediction=prediction, cached=False)
finally:
# Clean up in-flight tracker
in_flight.pop(cache_key, None)
|
This setup handles 1000+ requests/second on a single GPU with sub-50ms latency for cache hits and 80%+ GPU utilization. Use it as a starting point and tune batch size and TTL based on your actual traffic patterns.