The Quick Version
AI audit logging records every prediction your system makes — what input it received, which model processed it, what it predicted, and who requested it. This isn’t just good practice; regulations like the EU AI Act and NIST AI RMF increasingly require it for high-risk AI systems.
1
| pip install structlog sqlalchemy
|
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import structlog
import json
import uuid
from datetime import datetime, timezone
from typing import Optional

logger = structlog.get_logger()


def log_prediction(
    request_id: str,
    model_name: str,
    model_version: str,
    input_data: dict,
    output: dict,
    user_id: Optional[str] = None,
    latency_ms: Optional[float] = None,
) -> dict:
    """Log an AI prediction with full audit context.

    Args:
        request_id: Unique identifier correlating this prediction across systems.
        model_name: Name of the model that produced the prediction.
        model_version: Exact version string of the model (pin it for reproducibility).
        input_data: The payload the model received.
        output: The model's prediction payload.
        user_id: Who requested the prediction, if known.
        latency_ms: End-to-end inference latency in milliseconds, if measured.

    Returns:
        The structured audit record that was logged.
    """
    audit_record = {
        "event": "ai_prediction",
        "request_id": request_id,
        # Timezone-aware UTC so records from different hosts compare cleanly.
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "model": {
            "name": model_name,
            "version": model_version,
        },
        "input": input_data,
        "output": output,
        "user_id": user_id,
        "latency_ms": latency_ms,
    }
    logger.info("ai_prediction", **audit_record)
    return audit_record


# Usage in your inference code
request_id = str(uuid.uuid4())
log_prediction(
    request_id=request_id,
    model_name="fraud_detector",
    model_version="v3.2.1",
    input_data={"transaction_amount": 1250.00, "merchant_category": "electronics"},
    output={"prediction": "legitimate", "confidence": 0.94, "risk_score": 0.12},
    user_id="analyst_jane",
    latency_ms=23.5,
)
|
That logs a structured audit record for every prediction. You can query these logs later to answer questions like “why did the model flag this transaction?” or “which model version was running on March 15th?”
Persistent Audit Storage
For compliance, logs need to be durable, immutable, and queryable. Use a database with append-only semantics:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
from sqlalchemy import create_engine, Column, String, JSON, Float, DateTime, Integer
from sqlalchemy.orm import declarative_base, sessionmaker
from datetime import datetime, timezone
import hashlib
import json
import uuid

Base = declarative_base()


class AuditLog(Base):
    """Append-only record of a single model prediction."""

    __tablename__ = "ai_audit_logs"

    id = Column(Integer, primary_key=True, autoincrement=True)
    request_id = Column(String(36), unique=True, nullable=False, index=True)
    timestamp = Column(DateTime, nullable=False, index=True)
    model_name = Column(String(100), nullable=False, index=True)
    model_version = Column(String(50), nullable=False)
    input_hash = Column(String(64), nullable=False)  # SHA-256 of canonicalized input
    input_data = Column(JSON, nullable=False)
    output_data = Column(JSON, nullable=False)
    user_id = Column(String(100), index=True)
    latency_ms = Column(Float)
    # "metadata" is a reserved attribute name on SQLAlchemy declarative models
    # (it would shadow Base.metadata and raise InvalidRequestError), so the
    # Python attribute is extra_metadata while the DB column stays "metadata".
    extra_metadata = Column("metadata", JSON)


engine = create_engine("sqlite:///ai_audit.db")
Base.metadata.create_all(engine)
Session = sessionmaker(bind=engine)


class AuditLogger:
    """Writes and queries durable audit records for model predictions."""

    def __init__(self):
        self.session = Session()

    def log(self, model_name: str, model_version: str, input_data: dict,
            output_data: dict, user_id: str = None, latency_ms: float = None,
            metadata: dict = None) -> str:
        """Record a prediction in the audit log and return its request_id."""
        request_id = str(uuid.uuid4())
        # sort_keys canonicalizes the JSON so identical inputs always hash
        # identically regardless of dict insertion order.
        input_hash = hashlib.sha256(
            json.dumps(input_data, sort_keys=True).encode()
        ).hexdigest()
        record = AuditLog(
            request_id=request_id,
            timestamp=datetime.now(timezone.utc),
            model_name=model_name,
            model_version=model_version,
            input_hash=input_hash,
            input_data=input_data,
            output_data=output_data,
            user_id=user_id,
            latency_ms=latency_ms,
            extra_metadata=metadata,
        )
        self.session.add(record)
        self.session.commit()
        return request_id

    def query_by_model(self, model_name: str, limit: int = 100) -> list:
        """Most recent predictions made by a given model, newest first."""
        return self.session.query(AuditLog).filter(
            AuditLog.model_name == model_name
        ).order_by(AuditLog.timestamp.desc()).limit(limit).all()

    def query_by_user(self, user_id: str, limit: int = 100) -> list:
        """Most recent predictions requested by a given user, newest first."""
        return self.session.query(AuditLog).filter(
            AuditLog.user_id == user_id
        ).order_by(AuditLog.timestamp.desc()).limit(limit).all()

    def get_by_request_id(self, request_id: str) -> AuditLog:
        """Look up a single audit record by request_id (None if absent)."""
        return self.session.query(AuditLog).filter(
            AuditLog.request_id == request_id
        ).first()


audit = AuditLogger()
rid = audit.log(
    model_name="sentiment_classifier",
    model_version="v2.1.0",
    input_data={"text": "Great product, fast shipping!"},
    output_data={"label": "positive", "score": 0.96},
    user_id="api_client_42",
    latency_ms=15.3,
)
print(f"Logged: {rid}")
|
The input_hash field lets you detect duplicate inputs without storing them twice, and enables efficient lookups for “what did the model predict for this exact input before?”
Middleware for Automatic Logging
Don’t rely on developers remembering to call the logger. Wrap your inference endpoints with middleware that logs automatically:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
from fastapi import FastAPI, Request
from functools import wraps
import time

app = FastAPI()
audit = AuditLogger()


def audit_endpoint(model_name: str, model_version: str):
    """Decorator that automatically logs every prediction served by an endpoint.

    The wrapped handler must accept the raw Request as its first argument.
    Starlette caches the parsed body on the Request, so reading request.json()
    here and again inside the handler does not consume the stream twice.
    """
    def decorator(func):
        @wraps(func)
        async def wrapper(request: Request, *args, **kwargs):
            body = await request.json()
            user_id = request.headers.get("X-User-ID", "anonymous")
            # perf_counter is monotonic, so latency can't go negative or jump
            # if the wall clock is adjusted mid-request (time.time can).
            start = time.perf_counter()
            result = await func(request, *args, **kwargs)
            latency = (time.perf_counter() - start) * 1000
            audit.log(
                model_name=model_name,
                model_version=model_version,
                input_data=body,
                output_data=result,
                user_id=user_id,
                latency_ms=latency,
                metadata={
                    # request.client can be None (e.g. some test clients or
                    # UNIX-socket deployments), so guard the attribute access.
                    "ip": request.client.host if request.client else None,
                    "endpoint": str(request.url.path),
                },
            )
            return result
        return wrapper
    return decorator


@app.post("/predict")
@audit_endpoint(model_name="fraud_detector", model_version="v3.2.1")
async def predict(request: Request):
    body = await request.json()
    prediction = run_model(body)
    return prediction
|
Every request gets logged with zero changes to the prediction code.
Compliance Reports
Generate reports that auditors and regulators need:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
from datetime import datetime, timedelta, timezone
from collections import Counter
import json


class ComplianceReporter:
    """Builds the summary reports auditors and regulators typically request."""

    def __init__(self, audit_logger: AuditLogger):
        self.audit = audit_logger

    def model_usage_report(self, model_name: str, days: int = 30) -> dict:
        """Generate a usage report for a specific model.

        Args:
            model_name: Model to report on.
            days: Size of the trailing reporting window.

        Returns:
            Aggregate stats: volume, unique users, prediction mix, versions
            used, and latency summary for the window.
        """
        cutoff = datetime.now(timezone.utc) - timedelta(days=days)
        records = self.audit.session.query(AuditLog).filter(
            AuditLog.model_name == model_name,
            AuditLog.timestamp >= cutoff,
        ).all()
        if not records:
            return {"model": model_name, "period_days": days, "total_predictions": 0}
        # Different models label their output differently; fall back from
        # "prediction" to "label" before giving up (None).
        predictions = [r.output_data.get("prediction", r.output_data.get("label")) for r in records]
        # Explicit None check: a legitimate 0.0 ms latency must not be
        # silently dropped from the stats (plain truthiness would discard it).
        latencies = [r.latency_ms for r in records if r.latency_ms is not None]
        versions = Counter(r.model_version for r in records)
        return {
            "model": model_name,
            "period_days": days,
            "total_predictions": len(records),
            "unique_users": len(set(r.user_id for r in records if r.user_id)),
            "prediction_distribution": dict(Counter(predictions)),
            "versions_used": dict(versions),
            "latency_stats": {
                "mean_ms": sum(latencies) / len(latencies) if latencies else 0,
                # Nearest-rank p95; int() truncation keeps the index in range.
                "p95_ms": sorted(latencies)[int(len(latencies) * 0.95)] if latencies else 0,
                "max_ms": max(latencies) if latencies else 0,
            },
            "first_prediction": min(r.timestamp for r in records).isoformat(),
            "last_prediction": max(r.timestamp for r in records).isoformat(),
        }

    def data_access_report(self, user_id: str, days: int = 90) -> dict:
        """Report all AI predictions accessed by a specific user."""
        cutoff = datetime.now(timezone.utc) - timedelta(days=days)
        records = self.audit.session.query(AuditLog).filter(
            AuditLog.user_id == user_id,
            AuditLog.timestamp >= cutoff,
        ).all()
        return {
            "user_id": user_id,
            "period_days": days,
            "total_requests": len(records),
            "models_used": dict(Counter(r.model_name for r in records)),
            "requests": [
                {
                    "request_id": r.request_id,
                    "timestamp": r.timestamp.isoformat(),
                    "model": r.model_name,
                }
                for r in records
            ],
        }


reporter = ComplianceReporter(audit)
report = reporter.model_usage_report("fraud_detector", days=30)
print(json.dumps(report, indent=2))
|
Some audit requirements conflict with privacy requirements. You need to log that a prediction happened without storing sensitive input data:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
| import re
REDACTION_PATTERNS = {
"email": (r'\b[\w.+-]+@[\w-]+\.[\w.-]+\b', "[EMAIL]"),
"phone": (r'\b\d{3}[-.]?\d{3}[-.]?\d{4}\b', "[PHONE]"),
"ssn": (r'\b\d{3}-\d{2}-\d{4}\b', "[SSN]"),
"credit_card": (r'\b\d{4}[\s-]?\d{4}[\s-]?\d{4}[\s-]?\d{4}\b', "[CC]"),
}
def redact_pii(data: dict) -> dict:
"""Redact PII from audit log data while preserving structure."""
redacted = {}
for key, value in data.items():
if isinstance(value, str):
for pattern_name, (pattern, replacement) in REDACTION_PATTERNS.items():
value = re.sub(pattern, replacement, value)
redacted[key] = value
elif isinstance(value, dict):
redacted[key] = redact_pii(value)
else:
redacted[key] = value
return redacted
# Log with redaction. NOTE: the email must contain a dotted domain to match
# the email pattern; the original example used a bracketed placeholder that
# the regex could not match, so its printed output contradicted the comment.
input_data = {"text": "Contact [email protected] or call 555-123-4567", "user_ssn": "123-45-6789"}
redacted = redact_pii(input_data)
print(redacted)
# {'text': 'Contact [EMAIL] or call [PHONE]', 'user_ssn': '[SSN]'}
audit.log(
    model_name="classifier",
    model_version="v1",
    input_data=redacted,  # store the redacted version, never the raw input
    output_data={"label": "inquiry"},
)
|
Common Errors and Fixes
Audit logging slows down inference
Write logs asynchronously. Use a message queue (Redis, Kafka) as a buffer between your inference service and the audit database. The prediction returns immediately while the log is written in the background.
Database grows too large
Implement a retention policy. Move records older than 1 year to cold storage (S3, GCS) and keep only recent logs in the active database. Compress JSON fields and use columnar storage (Parquet) for archived logs.
Can’t reproduce a past prediction
Log the model version and all preprocessing parameters. If your preprocessing pipeline changes between versions, the same input can produce different outputs. Pin the full pipeline version, not just the model weights.
Audit records modified after the fact
Use append-only storage. In PostgreSQL, revoke UPDATE and DELETE permissions on the audit table. For maximum tamper-resistance, hash each record and chain them (similar to a blockchain) so any modification is detectable.
Missing audit records during high traffic
Async logging can drop records if the queue fills up. Add monitoring on queue depth and set up alerts when it exceeds a threshold. Use persistent queues (Kafka, RabbitMQ with disk persistence) instead of in-memory buffers.