The Quick Version

AI audit logging records every prediction your system makes — what input it received, which model processed it, what it predicted, and who requested it. This isn’t just good practice: regulations like the EU AI Act increasingly require it for high-risk AI systems, and voluntary frameworks like the NIST AI RMF recommend it.

1
pip install structlog sqlalchemy
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import structlog
import json
import uuid
from datetime import datetime, timezone

logger = structlog.get_logger()

def log_prediction(
    request_id: str,
    model_name: str,
    model_version: str,
    input_data: dict,
    output: dict,
    user_id: str = None,
    latency_ms: float = None,
) -> dict:
    """Log an AI prediction with full audit context.

    Args:
        request_id: Caller-supplied identifier for this prediction request.
        model_name: Name of the model that produced the prediction.
        model_version: Version string of that model.
        input_data: The input the model received.
        output: The prediction the model returned.
        user_id: Identifier of the requester, if known.
        latency_ms: Inference latency in milliseconds, if measured.

    Returns:
        The audit record that was logged, including its ``"event"`` name and
        a UTC ISO-8601 ``"timestamp"``.
    """
    audit_record = {
        "event": "ai_prediction",
        "request_id": request_id,
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "model": {
            "name": model_name,
            "version": model_version,
        },
        "input": input_data,
        "output": output,
        "user_id": user_id,
        "latency_ms": latency_ms,
    }

    # The logger takes the event name as its first positional argument, so
    # forwarding the "event" key again as a keyword would raise
    # "got multiple values for argument 'event'". Pass it exactly once.
    fields = {key: value for key, value in audit_record.items() if key != "event"}
    logger.info(audit_record["event"], **fields)
    return audit_record

# Usage in your inference code
# A fresh UUID is generated here and passed as the request_id so the logged
# record can be looked up by that identifier later.
request_id = str(uuid.uuid4())
log_prediction(
    request_id=request_id,
    model_name="fraud_detector",
    model_version="v3.2.1",
    input_data={"transaction_amount": 1250.00, "merchant_category": "electronics"},
    output={"prediction": "legitimate", "confidence": 0.94, "risk_score": 0.12},
    user_id="analyst_jane",
    latency_ms=23.5,
)

That logs a structured audit record for every prediction. You can query these logs later to answer questions like “why did the model flag this transaction?” or “which model version was running on March 15th?”

Persistent Audit Storage

For compliance, logs need to be durable, immutable, and queryable. Use a database with append-only semantics:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
from sqlalchemy import create_engine, Column, String, JSON, Float, DateTime, Integer
from sqlalchemy.orm import declarative_base, sessionmaker
from datetime import datetime, timezone
import uuid

Base = declarative_base()

class AuditLog(Base):
    """ORM row describing a single model prediction.

    Columns cover the request identifier, when the prediction happened, which
    model and version produced it, the input (plus its SHA-256 digest), the
    output, the requesting user, the latency and free-form extra metadata.
    """

    __tablename__ = "ai_audit_logs"

    id = Column(Integer, primary_key=True, autoincrement=True)
    request_id = Column(String(36), unique=True, nullable=False, index=True)
    timestamp = Column(DateTime, nullable=False, index=True)
    model_name = Column(String(100), nullable=False, index=True)
    model_version = Column(String(50), nullable=False)
    input_hash = Column(String(64), nullable=False)  # SHA-256 of input
    input_data = Column(JSON, nullable=False)
    output_data = Column(JSON, nullable=False)
    user_id = Column(String(100), index=True)
    latency_ms = Column(Float)
    # "metadata" is a reserved attribute name on declarative classes (it holds
    # the table registry), and using it makes the class definition raise.
    # The Python attribute is therefore renamed; the database column itself
    # is still called "metadata".
    extra_metadata = Column("metadata", JSON)

    def __init__(self, **kwargs):
        """Build a row, accepting ``metadata=`` as an alias for ``extra_metadata=``.

        The alias keeps callers that pass ``metadata=...`` working even though
        the mapped attribute could not keep that name.
        """
        if "metadata" in kwargs:
            kwargs["extra_metadata"] = kwargs.pop("metadata")
        super().__init__(**kwargs)

# Module-level database setup; the order of these three statements matters.
# The URL points at a SQLite database file named ai_audit.db.
engine = create_engine("sqlite:///ai_audit.db")
# Emit the schema for every model declared on Base (here: ai_audit_logs).
Base.metadata.create_all(engine)
# Session factory bound to the engine above; AuditLogger calls it to get a session.
Session = sessionmaker(bind=engine)

import hashlib

class AuditLogger:
    """Write prediction records to the audit table and query them back.

    Each instance owns one session created from the module-level ``Session``
    factory and reuses it for every call.
    """

    def __init__(self):
        self.session = Session()

    @staticmethod
    def _hash_input(input_data: dict) -> str:
        """Return the SHA-256 hex digest of ``input_data`` serialized as JSON.

        Keys are sorted before hashing so that two dicts with the same content
        but different insertion order produce the same digest.
        """
        canonical = json.dumps(input_data, sort_keys=True)
        return hashlib.sha256(canonical.encode()).hexdigest()

    def log(self, model_name: str, model_version: str, input_data: dict,
            output_data: dict, user_id: str = None, latency_ms: float = None,
            metadata: dict = None) -> str:
        """Record a prediction in the audit log.

        Args:
            model_name: Name of the model that produced the prediction.
            model_version: Version string of that model.
            input_data: JSON-serializable model input; also hashed.
            output_data: JSON-serializable model output.
            user_id: Identifier of the requester, if known.
            latency_ms: Inference latency in milliseconds, if measured.
            metadata: Optional extra context stored alongside the record.

        Returns:
            The newly generated request ID (a UUID4 string) of the record.

        Raises:
            Exception: Whatever the session raises while adding or committing;
                the session is rolled back before the error propagates.
        """
        request_id = str(uuid.uuid4())

        record = AuditLog(
            request_id=request_id,
            timestamp=datetime.now(timezone.utc),
            model_name=model_name,
            model_version=model_version,
            input_hash=self._hash_input(input_data),
            input_data=input_data,
            output_data=output_data,
            user_id=user_id,
            latency_ms=latency_ms,
            metadata=metadata,
        )

        try:
            self.session.add(record)
            self.session.commit()
        except Exception:
            # The session is shared by every later call. Without a rollback a
            # single failed commit would leave it unusable and every
            # subsequent audit write would fail as well.
            self.session.rollback()
            raise
        return request_id

    def query_by_model(self, model_name: str, limit: int = 100) -> list:
        """Return up to ``limit`` records for ``model_name``, newest first."""
        return self.session.query(AuditLog).filter(
            AuditLog.model_name == model_name
        ).order_by(AuditLog.timestamp.desc()).limit(limit).all()

    def query_by_user(self, user_id: str, limit: int = 100) -> list:
        """Return up to ``limit`` records requested by ``user_id``, newest first."""
        return self.session.query(AuditLog).filter(
            AuditLog.user_id == user_id
        ).order_by(AuditLog.timestamp.desc()).limit(limit).all()

    def get_by_request_id(self, request_id: str) -> "AuditLog | None":
        """Return the record with this request ID, or ``None`` if there is none."""
        return self.session.query(AuditLog).filter(
            AuditLog.request_id == request_id
        ).first()

# Example: store one prediction and print the request ID that log() returns.
audit = AuditLogger()
rid = audit.log(
    model_name="sentiment_classifier",
    model_version="v2.1.0",
    input_data={"text": "Great product, fast shipping!"},
    output_data={"label": "positive", "score": 0.96},
    user_id="api_client_42",
    latency_ms=15.3,
)
print(f"Logged: {rid}")

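The input_hash field lets you detect duplicate inputs by comparing a fixed-length digest instead of the full payload, and enables lookups for “what did the model predict for this exact input before?” The full input is still stored on every row, and the column is not indexed in the schema above, so add an index on input_hash if you run that lookup often.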

Middleware for Automatic Logging

Don’t rely on developers remembering to call the logger. Wrap your inference endpoints with middleware that logs automatically:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
from fastapi import FastAPI, Request
from functools import wraps
import time

app = FastAPI()
audit = AuditLogger()

def audit_endpoint(model_name: str, model_version: str):
    """Decorator that automatically logs all predictions.

    The wrapped coroutine must take the ``Request`` as its first argument and
    return the value that should be stored as the prediction output.

    Args:
        model_name: Model name recorded with every logged prediction.
        model_version: Model version recorded with every logged prediction.

    Returns:
        A decorator for async endpoint functions.
    """
    def decorator(func):
        @wraps(func)
        async def wrapper(request: Request, *args, **kwargs):
            body = await request.json()
            # NOTE(review): the user ID is read straight from a request
            # header; confirm that something upstream authenticates it.
            user_id = request.headers.get("X-User-ID", "anonymous")
            # perf_counter is monotonic, so the measured latency cannot be
            # skewed by wall-clock adjustments during the call.
            start = time.perf_counter()

            result = await func(request, *args, **kwargs)

            latency = (time.perf_counter() - start) * 1000
            # request.client can be None when the server does not supply the
            # peer address; record None rather than failing the request.
            client = request.client
            audit.log(
                model_name=model_name,
                model_version=model_version,
                input_data=body,
                output_data=result,
                user_id=user_id,
                latency_ms=latency,
                metadata={
                    "ip": client.host if client is not None else None,
                    "endpoint": str(request.url.path),
                },
            )

            return result
        return wrapper
    return decorator

@app.post("/predict")
@audit_endpoint(model_name="fraud_detector", model_version="v3.2.1")
async def predict(request: Request):
    """Run the model on the request's JSON body and return its prediction."""
    payload = await request.json()
    return run_model(payload)

Every request gets logged with zero changes to the prediction code.

Compliance Reports

Generate reports that auditors and regulators need:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
from datetime import datetime, timedelta, timezone
from collections import Counter

class ComplianceReporter:
    """Build summary reports from the records reachable through an AuditLogger."""

    def __init__(self, audit_logger: "AuditLogger"):
        self.audit = audit_logger

    def _records_since(self, days: int, *criteria) -> list:
        """Return all records matching ``criteria`` from the last ``days`` days."""
        cutoff = datetime.now(timezone.utc) - timedelta(days=days)
        return self.audit.session.query(AuditLog).filter(
            *criteria,
            AuditLog.timestamp >= cutoff,
        ).all()

    @staticmethod
    def _summarize_usage(model_name: str, days: int, records: list) -> dict:
        """Aggregate already-fetched records into the model usage report."""
        if not records:
            return {"model": model_name, "period_days": days, "total_predictions": 0}

        predictions = [
            r.output_data.get("prediction", r.output_data.get("label"))
            for r in records
        ]
        # Compare against None explicitly: a measured latency of 0.0 is a
        # valid data point and must not be dropped from the statistics.
        latencies = sorted(r.latency_ms for r in records if r.latency_ms is not None)
        timestamps = [r.timestamp for r in records]

        if latencies:
            latency_stats = {
                "mean_ms": sum(latencies) / len(latencies),
                # Index int(n * 0.95) is always < n, so it never overflows.
                "p95_ms": latencies[int(len(latencies) * 0.95)],
                "max_ms": latencies[-1],
            }
        else:
            latency_stats = {"mean_ms": 0, "p95_ms": 0, "max_ms": 0}

        return {
            "model": model_name,
            "period_days": days,
            "total_predictions": len(records),
            "unique_users": len({r.user_id for r in records if r.user_id}),
            "prediction_distribution": dict(Counter(predictions)),
            "versions_used": dict(Counter(r.model_version for r in records)),
            "latency_stats": latency_stats,
            "first_prediction": min(timestamps).isoformat(),
            "last_prediction": max(timestamps).isoformat(),
        }

    def model_usage_report(self, model_name: str, days: int = 30) -> dict:
        """Generate a usage report for a specific model.

        Args:
            model_name: Model whose records are summarized.
            days: Size of the look-back window, in days.

        Returns:
            A dict with prediction counts, distinct users, the distribution of
            predicted labels, versions seen, latency statistics and the first
            and last prediction timestamps. When no records match, only
            ``model``, ``period_days`` and ``total_predictions`` are present.
        """
        records = self._records_since(days, AuditLog.model_name == model_name)
        return self._summarize_usage(model_name, days, records)

    def data_access_report(self, user_id: str, days: int = 90) -> dict:
        """Report all AI predictions accessed by a specific user.

        Args:
            user_id: User whose requests are listed.
            days: Size of the look-back window, in days.

        Returns:
            A dict with the request count, a per-model request tally and one
            entry (request ID, timestamp, model) per matching record.
        """
        records = self._records_since(days, AuditLog.user_id == user_id)

        return {
            "user_id": user_id,
            "period_days": days,
            "total_requests": len(records),
            "models_used": dict(Counter(r.model_name for r in records)),
            "requests": [
                {
                    "request_id": r.request_id,
                    "timestamp": r.timestamp.isoformat(),
                    "model": r.model_name,
                }
                for r in records
            ],
        }

# Example: build a 30-day usage report for one model and print it as indented JSON.
reporter = ComplianceReporter(audit)
report = reporter.model_usage_report("fraud_detector", days=30)
print(json.dumps(report, indent=2))

Input/Output Redaction

Some audit requirements conflict with privacy requirements. You need to log that a prediction happened without storing sensitive input data:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import re

# Maps a PII category to (regular expression, replacement placeholder).
REDACTION_PATTERNS = {
    "email": (r'\b[\w.+-]+@[\w-]+\.[\w.-]+\b', "[EMAIL]"),
    "phone": (r'\b\d{3}[-.]?\d{3}[-.]?\d{4}\b', "[PHONE]"),
    "ssn": (r'\b\d{3}-\d{2}-\d{4}\b', "[SSN]"),
    "credit_card": (r'\b\d{4}[\s-]?\d{4}[\s-]?\d{4}[\s-]?\d{4}\b', "[CC]"),
}

def _redact_value(value):
    """Return ``value`` with every string it contains passed through the patterns."""
    if isinstance(value, str):
        for pattern, replacement in REDACTION_PATTERNS.values():
            value = re.sub(pattern, replacement, value)
        return value
    if isinstance(value, dict):
        return {key: _redact_value(item) for key, item in value.items()}
    # Strings inside sequences (and dicts nested in them) carry PII just as
    # often as top-level fields, so they must be walked too.
    if isinstance(value, list):
        return [_redact_value(item) for item in value]
    if isinstance(value, tuple):
        return tuple(_redact_value(item) for item in value)
    return value

def redact_pii(data: dict) -> dict:
    """Redact PII from audit log data while preserving structure.

    Every string value, at any nesting depth inside dicts, lists and tuples,
    has each pattern in ``REDACTION_PATTERNS`` replaced by its placeholder.
    Keys and non-string leaves are returned unchanged.

    Args:
        data: The mapping to sanitize; it is not modified.

    Returns:
        A new dict with the same shape as ``data`` and redacted strings.
    """
    return {key: _redact_value(value) for key, value in data.items()}

# Log with redaction
# The sample text contains an email address, a phone number and an SSN so that
# the printed result shows each placeholder.
input_data = {"text": "Contact jane.doe@example.com or call 555-123-4567", "user_ssn": "123-45-6789"}
redacted = redact_pii(input_data)
print(redacted)
# {'text': 'Contact [EMAIL] or call [PHONE]', 'user_ssn': '[SSN]'}

audit.log(
    model_name="classifier",
    model_version="v1",
    input_data=redacted,  # store redacted version
    output_data={"label": "inquiry"},
)

Common Errors and Fixes

Audit logging slows down inference

Write logs asynchronously. Use a message queue (Redis, Kafka) as a buffer between your inference service and the audit database. The prediction returns immediately while the log is written in the background.

Database grows too large

Implement a retention policy. Move records older than 1 year to cold storage (S3, GCS) and keep only recent logs in the active database. Compress JSON fields and use columnar storage (Parquet) for archived logs.

Can’t reproduce a past prediction

Log the model version and all preprocessing parameters. If your preprocessing pipeline changes between versions, the same input can produce different outputs. Pin the full pipeline version, not just the model weights.

Audit records modified after the fact

Use append-only storage. In PostgreSQL, revoke UPDATE and DELETE permissions on the audit table. For maximum tamper-resistance, hash each record and chain them (similar to a blockchain) so any modification is detectable.

Missing audit records during high traffic

Async logging can drop records if the queue fills up. Add monitoring on queue depth and set up alerts when it exceeds a threshold. Use persistent queues (Kafka, RabbitMQ with disk persistence) instead of in-memory buffers.