Every time you deploy a model serving container, it downloads the same 4 GB weights file from S3. Cold starts take minutes. Your autoscaler spins up a new pod and it sits there pulling bytes while requests queue up. The fix is a two-tier cache: check local disk first, fall back to S3, and only hit the origin (HuggingFace Hub, your model registry) as a last resort.

import os
import hashlib
import shutil
import time
import logging
from pathlib import Path
from dataclasses import dataclass, field

import boto3
from botocore.exceptions import ClientError

logger = logging.getLogger(__name__)

CACHE_DIR = os.environ.get("MODEL_CACHE_DIR", "/tmp/model_cache")
S3_BUCKET = os.environ.get("MODEL_CACHE_S3_BUCKET", "my-model-artifacts")
MAX_CACHE_BYTES = int(os.environ.get("MODEL_CACHE_MAX_BYTES", str(10 * 1024**3)))  # 10 GB default

That gives you the three knobs you need: where to store files locally, which S3 bucket to use, and how big the local cache can grow before eviction kicks in.
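To sanity-check how the size knob resolves, here's a standalone snippet (the values are arbitrary) showing the override-then-default behavior:

```python
import os

# Simulate setting the env var before the module reads it at import time.
os.environ["MODEL_CACHE_MAX_BYTES"] = str(2 * 1024**3)  # 2 GB override

max_cache_bytes = int(os.environ.get("MODEL_CACHE_MAX_BYTES", str(10 * 1024**3)))
print(max_cache_bytes)  # 2147483648

# With the variable unset, the 10 GB default applies.
del os.environ["MODEL_CACHE_MAX_BYTES"]
default_bytes = int(os.environ.get("MODEL_CACHE_MAX_BYTES", str(10 * 1024**3)))
print(default_bytes)  # 10737418240
```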

The ModelCache Class

The core idea is simple. get() checks local disk, then S3. put() writes to both. evict() removes from both. Every artifact is keyed by model_name/version and verified with a SHA256 checksum stored alongside it.

@dataclass
class CacheEntry:
    model_name: str
    version: str
    local_path: Path
    size_bytes: int
    last_accessed: float = field(default_factory=time.time)


class ModelCache:
    def __init__(
        self,
        cache_dir: str = CACHE_DIR,
        s3_bucket: str = S3_BUCKET,
        max_cache_bytes: int = MAX_CACHE_BYTES,
    ):
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.s3_bucket = s3_bucket
        self.max_cache_bytes = max_cache_bytes
        self.s3 = boto3.client("s3")
        self._entries: dict[str, CacheEntry] = {}
        self._rebuild_index()

    def _cache_key(self, model_name: str, version: str) -> str:
        return f"{model_name}/{version}"

    def _local_path(self, model_name: str, version: str) -> Path:
        return self.cache_dir / model_name / version

    def _s3_key(self, model_name: str, version: str) -> str:
        return f"cache/{model_name}/{version}/model.tar.gz"

    def _checksum_s3_key(self, model_name: str, version: str) -> str:
        return f"cache/{model_name}/{version}/sha256.txt"

    def _rebuild_index(self):
        """Scan the local cache directory and rebuild the in-memory index."""
        self._entries.clear()
        if not self.cache_dir.exists():
            return
        for model_dir in self.cache_dir.iterdir():
            if not model_dir.is_dir():
                continue
            for version_dir in model_dir.iterdir():
                if not version_dir.is_dir():
                    continue
                artifact = version_dir / "model.tar.gz"
                if artifact.exists():
                    key = self._cache_key(model_dir.name, version_dir.name)
                    self._entries[key] = CacheEntry(
                        model_name=model_dir.name,
                        version=version_dir.name,
                        local_path=artifact,
                        size_bytes=artifact.stat().st_size,
                        last_accessed=artifact.stat().st_atime,
                    )

    @staticmethod
    def sha256_file(filepath: Path) -> str:
        h = hashlib.sha256()
        with open(filepath, "rb") as f:
            for chunk in iter(lambda: f.read(8192), b""):
                h.update(chunk)
        return h.hexdigest()

    def _current_cache_size(self) -> int:
        return sum(e.size_bytes for e in self._entries.values())

    def _evict_lru(self, needed_bytes: int):
        """Remove least-recently-accessed entries until we have room."""
        while self._current_cache_size() + needed_bytes > self.max_cache_bytes:
            if not self._entries:
                break
            oldest_key = min(self._entries, key=lambda k: self._entries[k].last_accessed)
            entry = self._entries.pop(oldest_key)
            parent_dir = entry.local_path.parent
            shutil.rmtree(parent_dir, ignore_errors=True)
            logger.info("Evicted %s from local cache (LRU)", oldest_key)

    def get(self, model_name: str, version: str) -> Path | None:
        """Fetch a model artifact. Checks local cache, then S3."""
        key = self._cache_key(model_name, version)
        local_dir = self._local_path(model_name, version)
        artifact = local_dir / "model.tar.gz"
        checksum_file = local_dir / "sha256.txt"

        # Check local cache
        if key in self._entries and artifact.exists():
            self._entries[key].last_accessed = time.time()
            os.utime(artifact, None)  # refresh timestamps so _rebuild_index sees recent use
            logger.info("Cache HIT (local): %s", key)
            return artifact

        # Try S3
        s3_key = self._s3_key(model_name, version)
        try:
            head = self.s3.head_object(Bucket=self.s3_bucket, Key=s3_key)
            remote_size = head["ContentLength"]
        except ClientError as e:
            if e.response["Error"]["Code"] == "404":
                logger.info("Cache MISS (not in S3): %s", key)
                return None
            raise

        # Make room locally
        self._evict_lru(remote_size)
        local_dir.mkdir(parents=True, exist_ok=True)

        # Download artifact
        logger.info("Downloading %s from S3 (%d MB)...", key, remote_size // (1024 * 1024))
        self.s3.download_file(self.s3_bucket, s3_key, str(artifact))

        # Verify checksum
        try:
            checksum_resp = self.s3.get_object(
                Bucket=self.s3_bucket, Key=self._checksum_s3_key(model_name, version)
            )
            expected_hash = checksum_resp["Body"].read().decode("utf-8").strip()
            actual_hash = self.sha256_file(artifact)
            if actual_hash != expected_hash:
                artifact.unlink(missing_ok=True)
                logger.error("Checksum mismatch for %s: expected %s, got %s", key, expected_hash, actual_hash)
                return None
            checksum_file.write_text(expected_hash)
        except ClientError:
            logger.warning("No checksum found in S3 for %s, skipping verification", key)

        # Register in index
        self._entries[key] = CacheEntry(
            model_name=model_name,
            version=version,
            local_path=artifact,
            size_bytes=remote_size,
            last_accessed=time.time(),
        )
        logger.info("Cache HIT (S3): %s", key)
        return artifact

    def put(self, model_name: str, version: str, source_path: str) -> str:
        """Upload an artifact to S3 and store locally. Returns the SHA256 hash."""
        key = self._cache_key(model_name, version)
        source = Path(source_path)
        if not source.exists():
            raise FileNotFoundError(f"Source artifact not found: {source_path}")

        file_hash = self.sha256_file(source)
        file_size = source.stat().st_size

        # Upload to S3
        s3_key = self._s3_key(model_name, version)
        logger.info("Uploading %s to S3 (%d MB)...", key, file_size // (1024 * 1024))
        self.s3.upload_file(str(source), self.s3_bucket, s3_key)
        self.s3.put_object(
            Bucket=self.s3_bucket,
            Key=self._checksum_s3_key(model_name, version),
            Body=file_hash.encode("utf-8"),
        )

        # Copy to local cache
        self._evict_lru(file_size)
        local_dir = self._local_path(model_name, version)
        local_dir.mkdir(parents=True, exist_ok=True)
        artifact = local_dir / "model.tar.gz"
        shutil.copy2(str(source), str(artifact))
        (local_dir / "sha256.txt").write_text(file_hash)

        self._entries[key] = CacheEntry(
            model_name=model_name,
            version=version,
            local_path=artifact,
            size_bytes=file_size,
            last_accessed=time.time(),
        )
        logger.info("Cached %s (sha256: %s)", key, file_hash[:16])
        return file_hash

    def evict(self, model_name: str, version: str):
        """Remove an artifact from both local cache and S3."""
        key = self._cache_key(model_name, version)

        # Remove local
        local_dir = self._local_path(model_name, version)
        if local_dir.exists():
            shutil.rmtree(local_dir)
        self._entries.pop(key, None)

        # Remove from S3
        self.s3.delete_objects(
            Bucket=self.s3_bucket,
            Delete={
                "Objects": [
                    {"Key": self._s3_key(model_name, version)},
                    {"Key": self._checksum_s3_key(model_name, version)},
                ]
            },
        )
        logger.info("Evicted %s from all caches", key)

That’s the full class. A few design choices worth calling out: the LRU eviction runs before downloading so you don’t blow past your disk limit. The checksum verification is optional — if someone stored an artifact without a checksum file, the download still works but logs a warning. The in-memory index gets rebuilt from disk on startup, so you survive process restarts without losing your cache state.
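To see the eviction policy in isolation, here's a stripped-down sketch of the same LRU loop over a plain in-memory index — no S3 or filesystem, and the key names, sizes, and timestamps are made up:

```python
# key -> (size_bytes, last_accessed); a toy stand-in for ModelCache._entries
entries = {
    "bert/v1": (4_000, 100.0),   # oldest
    "bert/v2": (4_000, 200.0),
    "gpt/v1":  (4_000, 300.0),   # newest
}

def evict_lru(entries, needed_bytes, max_bytes):
    """Drop least-recently-accessed entries until needed_bytes fits under max_bytes."""
    evicted = []
    while sum(size for size, _ in entries.values()) + needed_bytes > max_bytes:
        if not entries:
            break
        oldest = min(entries, key=lambda k: entries[k][1])
        entries.pop(oldest)
        evicted.append(oldest)
    return evicted

# Asking for 4,000 more bytes against a 10,000-byte cap forces two evictions.
evicted = evict_lru(entries, needed_bytes=4_000, max_bytes=10_000)
print(evicted)  # ['bert/v1', 'bert/v2']
```

The two oldest entries go first, and the loop stops as soon as the new artifact fits — the same behavior `_evict_lru` has against disk.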

Integrating with FastAPI

Here’s how to wire the cache into a model serving endpoint. This uses the lifespan context manager — the recommended way to handle startup/shutdown in modern FastAPI, replacing the deprecated @app.on_event hooks.

from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import tarfile
import tempfile

cache: ModelCache | None = None
loaded_model = None  # Replace with your actual model type

def load_model_from_artifact(artifact_path: Path):
    """Extract tarball and load model weights. Replace with your actual loading logic."""
    extract_dir = tempfile.mkdtemp()
    with tarfile.open(artifact_path, "r:gz") as tar:
        tar.extractall(path=extract_dir, filter="data")  # "data" filter (Python 3.12+, backported) blocks path traversal
    # Your model loading here, e.g.:
    # return torch.load(os.path.join(extract_dir, "model.pt"))
    return {"status": "loaded", "path": extract_dir}

@asynccontextmanager
async def lifespan(app: FastAPI):
    global cache, loaded_model
    model_name = os.environ.get("MODEL_NAME", "text-classifier")
    model_version = os.environ.get("MODEL_VERSION", "v1.2.0")

    cache = ModelCache()
    artifact_path = cache.get(model_name, model_version)

    if artifact_path is None:
        raise RuntimeError(f"Model {model_name}:{model_version} not found in any cache tier")

    loaded_model = load_model_from_artifact(artifact_path)
    logger.info("Model %s:%s loaded and ready", model_name, model_version)
    yield
    loaded_model = None

app = FastAPI(lifespan=lifespan)

class PredictRequest(BaseModel):
    text: str

class PredictResponse(BaseModel):
    label: str
    confidence: float

@app.post("/predict", response_model=PredictResponse)
async def predict(req: PredictRequest):
    if loaded_model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    # Replace with actual inference
    return PredictResponse(label="positive", confidence=0.95)

@app.get("/cache/stats")
async def cache_stats():
    if cache is None:
        raise HTTPException(status_code=503, detail="Cache not initialized")
    entries = []
    for key, entry in cache._entries.items():
        entries.append({
            "key": key,
            "size_mb": round(entry.size_bytes / (1024 * 1024), 1),
            "last_accessed": entry.last_accessed,
        })
    return {
        "total_entries": len(cache._entries),
        "total_size_mb": round(cache._current_cache_size() / (1024 * 1024), 1),
        "max_size_mb": round(cache.max_cache_bytes / (1024 * 1024), 1),
        "entries": entries,
    }

The /cache/stats endpoint is useful for debugging in production. Hit it to see what’s cached, how much disk you’re using, and when each artifact was last accessed.

Cache Hit/Miss Metrics

Logging alone won’t cut it at scale. You want counters you can scrape with Prometheus or push to your metrics backend. Here’s a lightweight wrapper that tracks hit rates.

from dataclasses import dataclass, field
from enum import Enum

class CacheTier(Enum):
    LOCAL = "local"
    S3 = "s3"
    MISS = "miss"

@dataclass
class CacheMetrics:
    hits_local: int = 0
    hits_s3: int = 0
    misses: int = 0
    evictions: int = 0
    put_count: int = 0
    download_bytes: int = 0

    def record_hit(self, tier: CacheTier):
        if tier == CacheTier.LOCAL:
            self.hits_local += 1
        elif tier == CacheTier.S3:
            self.hits_s3 += 1
        else:
            self.misses += 1

    def hit_rate(self) -> float:
        total = self.hits_local + self.hits_s3 + self.misses
        if total == 0:
            return 0.0
        return (self.hits_local + self.hits_s3) / total

    def local_hit_rate(self) -> float:
        total = self.hits_local + self.hits_s3 + self.misses
        if total == 0:
            return 0.0
        return self.hits_local / total

    def to_dict(self) -> dict:
        return {
            "hits_local": self.hits_local,
            "hits_s3": self.hits_s3,
            "misses": self.misses,
            "evictions": self.evictions,
            "put_count": self.put_count,
            "download_bytes_mb": round(self.download_bytes / (1024 * 1024), 1),
            "hit_rate": round(self.hit_rate(), 4),
            "local_hit_rate": round(self.local_hit_rate(), 4),
        }

Drop a CacheMetrics instance into your ModelCache.__init__ and call record_hit() in the appropriate branches of get(). Then expose metrics.to_dict() on a /cache/metrics endpoint. In a real deployment, you’d also export these as Prometheus gauges, but the pattern is the same.
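In a real deployment you'd reach for the official prometheus_client library, but as a sketch of what a scrape sees, here's the text exposition format rendered by hand from a flat metrics dict (the metric prefix and sample values are illustrative):

```python
def to_prometheus_text(metrics: dict, prefix: str = "model_cache") -> str:
    """Render a flat metrics dict as Prometheus text exposition lines."""
    lines = []
    for name, value in metrics.items():
        lines.append(f"# TYPE {prefix}_{name} gauge")
        lines.append(f"{prefix}_{name} {value}")
    return "\n".join(lines) + "\n"

sample = {"hits_local": 42, "hits_s3": 7, "misses": 3, "hit_rate": 0.9423}
print(to_prometheus_text(sample))
```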

Common Errors and Fixes

botocore.exceptions.NoCredentialsError — Your container doesn’t have AWS credentials. If running on ECS or EKS, attach an IAM role to the task/pod. For local development, set AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY or use aws configure.

FileNotFoundError: /tmp/model_cache/... — The local cache directory doesn’t exist or the container’s filesystem is read-only. Make sure MODEL_CACHE_DIR points to a writable volume. On Kubernetes, mount an emptyDir volume at that path.

Checksum mismatch after download — Usually means a partial download or a corrupted file in S3. Re-upload the artifact with put() to regenerate the checksum. You can also add retry logic around the download_file call — boto3 retries on transient errors by default, but network interruptions mid-transfer won’t be caught.
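One way to harden that path is a retry loop that re-verifies the hash after each attempt. This sketch takes a generic download callable so it stays library-agnostic; the function name, attempt count, and backoff values are illustrative:

```python
import hashlib
import tempfile
import time
from pathlib import Path

def download_with_verification(download_fn, dest: Path, expected_sha256: str,
                               max_attempts: int = 3, backoff_s: float = 0.0) -> bool:
    """Retry download_fn(dest) until the file's SHA256 matches expected_sha256."""
    for attempt in range(1, max_attempts + 1):
        download_fn(dest)  # e.g. lambda p: s3.download_file(bucket, key, str(p))
        actual = hashlib.sha256(dest.read_bytes()).hexdigest()
        if actual == expected_sha256:
            return True
        dest.unlink(missing_ok=True)  # throw away the corrupt copy before retrying
        time.sleep(backoff_s * attempt)
    return False

# Simulate a download that is corrupt on the first attempt, then succeeds.
attempts = {"n": 0}
def fake_download(p: Path):
    attempts["n"] += 1
    p.write_bytes(b"corrupt" if attempts["n"] == 1 else b"weights")

good_hash = hashlib.sha256(b"weights").hexdigest()
with tempfile.TemporaryDirectory() as d:
    ok = download_with_verification(fake_download, Path(d) / "model.tar.gz", good_hash)
print(ok, attempts["n"])  # True 2
```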

Cache fills up instantly with large models — Set MODEL_CACHE_MAX_BYTES to something reasonable for your disk. If you’re serving a single 7B model (~14 GB in fp16), you need at least 15-16 GB of cache space. The LRU eviction only helps when you have multiple model versions rotating through.

S3 SlowDown errors under heavy autoscaling — When 50 pods all start at once, they all hit S3 simultaneously. Add jitter to your startup: time.sleep(random.uniform(0, 5)) before the first cache.get() call. Or better yet, pre-warm your local cache by baking the most common model into your container image.

tarfile.ReadError: not a gzip file — The artifact in your cache isn’t actually a gzipped tarball. This happens when someone uploads a raw model file instead of a tarball. Either enforce the tar.gz format in put() with a check, or adapt load_model_from_artifact to handle both formats.
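To enforce the format in put(), a cheap pre-upload check works — the helper name and the empty-archive fixture below are just for illustration:

```python
import gzip
import tarfile
import tempfile
from pathlib import Path

def is_gzipped_tarball(path: Path) -> bool:
    """Return True if path opens as a gzipped tarball; run this in put() before uploading."""
    try:
        with tarfile.open(path, "r:gz") as tar:
            tar.next()  # force a header read so truncated files fail here
        return True
    except (tarfile.ReadError, gzip.BadGzipFile, OSError):
        return False

with tempfile.TemporaryDirectory() as d:
    good = Path(d) / "model.tar.gz"
    with tarfile.open(good, "w:gz"):
        pass  # empty but structurally valid archive
    bad = Path(d) / "weights.bin"
    bad.write_bytes(b"\x00" * 64)  # raw bytes, not an archive
    results = (is_gzipped_tarball(good), is_gzipped_tarball(bad))
print(results)  # (True, False)
```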