When your API gets a burst of inference requests, you have two options: block every request until the model finishes, or drop the requests into a queue and process them asynchronously. For slow or bursty inference the queue approach usually wins, because the API can answer immediately while workers drain the backlog at the pace the hardware allows. Celery with Redis as the broker gives you distributed task processing with configurable retries, priority routing, and result storage, all without building your own job scheduler.

pip install "celery[redis]" redis torch torchvision fastapi

You need a running Redis instance. If you have Docker:

docker run -d --name redis-inference -p 6379:6379 redis:7-alpine

Setting Up Celery with Redis

Create a celery_app.py file that configures Celery to use Redis as both the message broker and result backend.

from celery import Celery

app = Celery("inference")

app.conf.update(
    broker_url="redis://localhost:6379/0",
    result_backend="redis://localhost:6379/1",
    task_serializer="json",
    result_serializer="json",
    accept_content=["json"],
    task_track_started=True,
    task_acks_late=True,          # re-queue tasks if a worker crashes mid-inference
    worker_prefetch_multiplier=1, # grab one task at a time (important for GPU workers)
    result_expires=3600,          # results expire after 1 hour
)

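Three settings matter for ML workloads. task_acks_late=True makes the worker acknowledge a task only after it finishes rather than when it is received, so if a worker dies mid-inference the message is delivered again to another worker. With the Redis broker that redelivery happens only after the transport's visibility_timeout (one hour by default), and it means a task can occasionally run twice, which is harmless for a pure prediction. worker_prefetch_multiplier=1 stops a busy worker from reserving extra tasks, which matters when each task needs the whole GPU. And result_expires keeps Redis memory from growing without bound as old predictions accumulate.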
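Because task_acks_late depends on the broker redelivering unacknowledged messages, it can help to set the Redis visibility timeout explicitly, and to read the connection URLs from the environment instead of hard-coding localhost (which breaks inside containers). A minimal sketch, assuming hypothetical environment variable names CELERY_BROKER_URL and CELERY_RESULT_BACKEND and an illustrative two-hour timeout; the returned dict is meant to be passed to app.conf.update(...).

```python
import os
from typing import Mapping, Optional


def build_celery_config(env: Optional[Mapping[str, str]] = None) -> dict:
    """Build a settings dict for app.conf.update().

    CELERY_BROKER_URL / CELERY_RESULT_BACKEND are hypothetical variable names
    chosen for this sketch; the defaults match the local setup in this article.
    """
    env = os.environ if env is None else env
    return {
        "broker_url": env.get("CELERY_BROKER_URL", "redis://localhost:6379/0"),
        "result_backend": env.get("CELERY_RESULT_BACKEND", "redis://localhost:6379/1"),
        "task_serializer": "json",
        "result_serializer": "json",
        "accept_content": ["json"],
        "task_track_started": True,
        "task_acks_late": True,
        "worker_prefetch_multiplier": 1,
        "result_expires": 3600,
        # With the Redis broker, an unacknowledged task is redelivered after this
        # many seconds. Keep it above your longest task (7200 is an assumption),
        # or a task that is still running can be handed to a second worker.
        "broker_transport_options": {"visibility_timeout": 7200},
    }
```

In celery_app.py you would then call app.conf.update(build_celery_config()) in place of the literal settings.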

Building the Inference Task

The trick is loading the model once when the worker process starts, not on every task invocation. Use Celery’s worker_process_init signal to load the model into a global variable.

import torch
from celery import Celery
from celery.signals import worker_process_init

app = Celery("inference")
app.conf.update(
    broker_url="redis://localhost:6379/0",
    result_backend="redis://localhost:6379/1",
    task_serializer="json",
    result_serializer="json",
    accept_content=["json"],
    task_track_started=True,
    task_acks_late=True,
    worker_prefetch_multiplier=1,
    result_expires=3600,
)

model = None
device = None


@worker_process_init.connect
def load_model(**kwargs):
    global model, device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = torch.jit.load("model_scripted.pt", map_location=device)
    model.eval()
    print(f"Model loaded on {device}")


@app.task(bind=True, max_retries=2, default_retry_delay=5)
def predict(self, input_data: list[float]) -> dict:
    """Run inference on a single input and return the prediction."""
    try:
        tensor = torch.tensor([input_data], dtype=torch.float32).to(device)
        with torch.no_grad():
            output = model(tensor)
        probabilities = torch.softmax(output, dim=1)
        predicted_class = torch.argmax(probabilities, dim=1).item()
        confidence = probabilities[0][predicted_class].item()
        return {
            "class": predicted_class,
            "confidence": round(confidence, 4),
            "device": str(device),
        }
    except RuntimeError as exc:
        raise self.retry(exc=exc)
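The worker_process_init hook only helps if your pool sends that signal. A pool-independent alternative is to load the model lazily on the first task and guard the load with a lock, so concurrent threads cannot load it twice. A minimal stdlib sketch; LazyModel is a hypothetical helper, and load_fn stands in for something like a wrapper around torch.jit.load(...).

```python
import threading
from typing import Any, Callable


class LazyModel:
    """Load a model once per process, on first use, in a thread-safe way."""

    def __init__(self, load_fn: Callable[[], Any]):
        self._load_fn = load_fn  # hypothetical loader, e.g. wraps torch.jit.load
        self._lock = threading.Lock()
        self._model = None

    def get(self) -> Any:
        # Fast path: already loaded, no locking needed.
        if self._model is not None:
            return self._model
        with self._lock:
            # Re-check inside the lock: another thread may have finished
            # loading while this one was waiting.
            if self._model is None:
                self._model = self._load_fn()
        return self._model
```

Inside predict you would then call the holder's get() instead of reading the global model. The trade-off is that the first task on each worker pays the load time.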

Start the worker with one process per GPU (concurrency=1 per process for GPU work):

celery -A celery_app worker --pool=solo --concurrency=1 --loglevel=info -n worker1@%h

The --pool=solo flag runs tasks in the main process, which avoids forking issues with CUDA contexts. Each worker process owns one GPU. To use multiple GPUs, start multiple workers and set CUDA_VISIBLE_DEVICES for each:

CUDA_VISIBLE_DEVICES=0 celery -A celery_app worker --pool=solo -n gpu0@%h &
CUDA_VISIBLE_DEVICES=1 celery -A celery_app worker --pool=solo -n gpu1@%h &
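Instead of backgrounding one shell command per GPU by hand, a small launcher can build the same commands for any number of GPUs. A minimal sketch, assuming GPUs are numbered 0 to n-1 and that the celery executable is on PATH; gpu_worker_commands and launch are hypothetical helper names.

```python
import os
import subprocess


def gpu_worker_commands(num_gpus: int, app_module: str = "celery_app") -> list[tuple[dict[str, str], list[str]]]:
    """Return (env, argv) pairs: one solo-pool Celery worker per GPU.

    Assumes GPUs are numbered 0..num_gpus-1.
    """
    commands = []
    for gpu in range(num_gpus):
        env = {**os.environ, "CUDA_VISIBLE_DEVICES": str(gpu)}
        # %h is expanded by Celery itself, so no shell is needed.
        argv = ["celery", "-A", app_module, "worker", "--pool=solo", "-n", f"gpu{gpu}@%h"]
        commands.append((env, argv))
    return commands


def launch(num_gpus: int) -> list[subprocess.Popen]:
    # Each worker is its own OS process, so each owns exactly one GPU.
    return [subprocess.Popen(argv, env=env) for env, argv in gpu_worker_commands(num_gpus)]
```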

Submitting Tasks from Your API

Here are two FastAPI endpoints: one queues inference and returns the task ID, the other reports the task's status and result:

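from fastapi import FastAPI
from celery.result import AsyncResult
from celery_app import app, predict

api = FastAPI()


@api.post("/predict")
def submit_prediction(input_data: list[float]):
    task = predict.delay(input_data)
    return {"task_id": task.id}


@api.get("/result/{task_id}")
def get_result(task_id: str):
    # Bind to the Celery app explicitly so the configured result backend is used.
    result = AsyncResult(task_id, app=app)
    if result.successful():
        return {"status": "completed", "result": result.result}
    if result.failed():
        # result.get() would re-raise the task's exception and turn this
        # request into a 500, so report the failure instead.
        return {"status": "failed", "error": str(result.result)}
    # PENDING also covers unknown task IDs; Celery cannot tell them apart.
    return {"status": result.state}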
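On the client side, the usual pattern is to POST the input, then poll /result/{task_id} with a growing delay until the task leaves the in-progress states. A minimal sketch that keeps the HTTP call pluggable: fetch_status is a hypothetical callable you would implement with your HTTP client against the endpoint above, and the timeout and backoff values are illustrative.

```python
import time
from typing import Callable

# Celery states that mean "not finished yet"; anything else is treated as final.
IN_PROGRESS = {"PENDING", "RECEIVED", "STARTED", "RETRY"}


def wait_for_result(
    task_id: str,
    fetch_status: Callable[[str], dict],
    timeout: float = 30.0,
    initial_delay: float = 0.1,
    max_delay: float = 2.0,
) -> dict:
    """Poll fetch_status(task_id) until the task leaves IN_PROGRESS.

    fetch_status is a hypothetical callable returning the JSON body of
    GET /result/{task_id}, e.g. {"status": "completed", "result": {...}}.
    """
    deadline = time.monotonic() + timeout
    delay = initial_delay
    while True:
        body = fetch_status(task_id)
        if body["status"] not in IN_PROGRESS:
            return body
        if time.monotonic() + delay > deadline:
            raise TimeoutError(f"task {task_id} still {body['status']} after {timeout}s")
        time.sleep(delay)
        delay = min(delay * 2, max_delay)  # exponential backoff, capped at max_delay
```

Because a mistyped task ID stays PENDING forever, the timeout is what stops the client from polling indefinitely.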

Priority Queues for Different SLAs

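Not all inference requests are equal. Real-time API calls need sub-second latency. Batch scoring jobs can wait. Route them to separate queues served by separate workers, so a large batch job never blocks an urgent request; message priorities then decide the order of tasks within a single queue.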

Define the queues in your Celery config:

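from kombu import Queue

app.conf.update(
    task_queues=(
        Queue("realtime", routing_key="realtime"),
        Queue("batch", routing_key="batch"),
        Queue("default", routing_key="default"),
    ),
    task_default_queue="default",
    task_default_routing_key="default",
    task_routes={
        "celery_app.predict": {"queue": "realtime"},
        "celery_app.batch_predict": {"queue": "batch"},
    },
    # Redis ignores "x-max-priority" (a RabbitMQ queue argument). Celery's Redis
    # transport emulates priorities with one list per level instead; these
    # options use all ten levels (0 = highest) rather than the default four.
    broker_transport_options={
        **app.conf.broker_transport_options,  # keep any options set elsewhere
        "priority_steps": list(range(10)),
        "sep": ":",
        "queue_order_strategy": "priority",
    },
)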
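Redis has no native priority queues. According to Celery's routing documentation, the Redis transport splits each queue into one Redis list per priority step (by default the ten levels are folded into four: 0, 3, 6, 9), names the priority-0 list after the queue itself and suffixes the others with a separator, and checks the lists from lowest number to highest. The sketch below models that key naming so you can see where a message lands; it is an approximation of kombu's behavior, not kombu's own code, so verify it against the version you run.

```python
from bisect import bisect
from typing import Sequence

DEFAULT_PRIORITY_STEPS = (0, 3, 6, 9)  # default: 10 levels folded into 4 lists
DEFAULT_SEP = "\x06\x16"               # default separator between name and step


def redis_list_for(queue: str, priority: int,
                   steps: Sequence[int] = DEFAULT_PRIORITY_STEPS,
                   sep: str = DEFAULT_SEP) -> str:
    """Model of the Redis list a message with `priority` is pushed to.

    Priority 0 is the highest. The value is clamped to 0-9 and rounded down
    to the nearest step; every step except 0 gets a suffixed key.
    """
    priority = max(0, min(priority, 9))
    step = steps[bisect(steps, priority) - 1]
    return f"{queue}{sep}{step}" if step else queue
```

For example, with the defaults a priority-5 message goes to "realtime\x06\x163", while with priority_steps=list(range(10)) and sep=":" it goes to "realtime:5". This is also why redis-cli LLEN realtime counts only part of a queue once priorities are in use.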

Now define a batch prediction task that processes multiple inputs in one pass:

@app.task(bind=True, max_retries=1)
def batch_predict(self, input_batch: list[list[float]]) -> list[dict]:
    """Run inference on a batch of inputs."""
    try:
        tensor = torch.tensor(input_batch, dtype=torch.float32).to(device)
        with torch.no_grad():
            output = model(tensor)
        probabilities = torch.softmax(output, dim=1)
        predicted_classes = torch.argmax(probabilities, dim=1).tolist()
        confidences = probabilities.max(dim=1).values.tolist()
        return [
            {"class": cls, "confidence": round(conf, 4)}
            for cls, conf in zip(predicted_classes, confidences)
        ]
    except RuntimeError as exc:
        raise self.retry(exc=exc)

Start dedicated workers for each queue:

# Fast workers for real-time queue
CUDA_VISIBLE_DEVICES=0 celery -A celery_app worker --pool=solo -Q realtime -n rt-worker@%h

# Separate workers for batch queue
CUDA_VISIBLE_DEVICES=1 celery -A celery_app worker --pool=solo -Q batch -n batch-worker@%h

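To send a task with an explicit priority, keep in mind that on the Redis broker lower numbers win: 0 is the highest priority and 9 the lowest, the reverse of RabbitMQ. Messages sent without a priority also default to 0 on Redis, so for an urgent task to actually jump ahead, routine tasks in the same queue need a larger number (for example via the task_default_priority setting). Priority only orders tasks within one queue; the separate realtime and batch workers are what keep batch jobs from delaying API calls.

# Highest priority on Redis -- checked before higher-numbered tasks in the realtime queue
predict.apply_async(args=[[0.1, 0.5, 0.3]], queue="realtime", priority=0)

# Lowest priority batch job (large_input_list stands in for your data)
batch_predict.apply_async(args=[large_input_list], queue="batch", priority=9)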
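batch_predict receives the whole list as one JSON message and runs it as one tensor, so a very large large_input_list means a huge message, a long-running task, and possibly a GPU out-of-memory error. A common approach is to split it into fixed-size chunks and enqueue one task per chunk. A minimal sketch; chunked is a hypothetical helper and the chunk size is yours to tune against model and GPU memory.

```python
from itertools import islice
from typing import Iterable, Iterator


def chunked(rows: Iterable[list[float]], size: int) -> Iterator[list[list[float]]]:
    """Yield consecutive lists of at most `size` rows, preserving order."""
    if size < 1:
        raise ValueError("size must be >= 1")
    it = iter(rows)
    # islice pulls the next `size` rows; an empty chunk means we are done.
    while chunk := list(islice(it, size)):
        yield chunk
```

You would then call batch_predict.apply_async(args=[chunk], queue="batch") for each chunk, or wrap the signatures in Celery's group(...) if you need to collect all chunk results together.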

Monitoring and Scaling Workers

Flower is the standard monitoring tool for Celery. It gives you a web dashboard with worker status, task rates, and queue depths.

pip install flower
celery -A celery_app flower --port=5555

Open http://localhost:5555 to see active workers, task success/failure rates, and queue lengths.

For programmatic monitoring, query Celery’s inspection API directly:

from kombu.exceptions import ChannelError

from celery_app import app

inspector = app.control.inspect()

# Check which workers are online (None means no worker replied in time)
active_workers = inspector.active_queues()
print(f"Active workers: {list(active_workers.keys()) if active_workers else 'none'}")

# Count tasks waiting in each queue
with app.connection_or_acquire() as conn:
    for queue_name in ["realtime", "batch", "default"]:
        try:
            queue = conn.default_channel.queue_declare(queue=queue_name, passive=True)
            print(f"Queue '{queue_name}': {queue.message_count} pending tasks")
        except ChannelError:
            # Redis deletes empty lists, so an empty queue looks the same
            # as a missing one to a passive declare.
            print(f"Queue '{queue_name}': empty (no list in Redis)")

# Check what each worker is currently running
active_tasks = inspector.active()
if active_tasks:
    for worker, tasks in active_tasks.items():
        print(f"{worker}: {len(tasks)} active tasks")
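The reply from inspector.active_queues() maps each worker name to a list of queue descriptions, each carrying a "name" key (or is None when no worker answers). Inverting it into a queue-to-workers mapping shows at a glance whether a queue has lost all its consumers. A minimal sketch; workers_per_queue is a hypothetical helper and only relies on the "name" field of each queue description.

```python
from collections import defaultdict
from typing import Optional


def workers_per_queue(active_queues_reply: Optional[dict]) -> dict[str, list[str]]:
    """Invert {worker: [queue_info, ...]} into {queue_name: [worker, ...]}."""
    by_queue = defaultdict(list)
    # inspect() methods return None when no worker replies in time.
    for worker, queues in (active_queues_reply or {}).items():
        for queue_info in queues:
            by_queue[queue_info["name"]].append(worker)
    return dict(by_queue)
```

For instance, workers_per_queue(inspector.active_queues()).get("batch", []) returning an empty list means batch tasks are piling up with nothing to run them.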

Autoscaling Workers

Celery has a built-in autoscaler that grows and shrinks a worker's pool of processes based on that worker's load. For CPU workers this works well:

celery -A celery_app worker --autoscale=8,2 -Q batch -n batch-auto@%h

This scales between 2 and 8 concurrent processes. For GPU workers, autoscaling concurrency inside a process does not make sense, since each solo-pool worker runs one task at a time on its GPU; you scale by adding more worker processes. A simple shell script that watches queue depth and starts workers:

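#!/bin/bash
# Starts extra realtime workers when the queue backs up.
# Assumes GPUs 0..MAX_WORKERS-1 exist and that running workers occupy GPUs
# 0..WORKERS-1, so the next free GPU index equals the current worker count.
QUEUE="realtime"
THRESHOLD=50
MAX_WORKERS=4

while true; do
    # LLEN counts only the base list; with Redis priorities in use, messages
    # with a nonzero priority sit in additional, suffixed lists.
    DEPTH=$(redis-cli LLEN "$QUEUE" 2>/dev/null || echo 0)
    WORKERS=$(celery -A celery_app inspect active_queues --json 2>/dev/null | python3 -c "
import sys, json
data = json.load(sys.stdin)
count = sum(1 for w in data.values() for q in w if q['name'] == '$QUEUE')
print(count)
" 2>/dev/null || echo 0)
    echo "Queue depth: $DEPTH, Workers: $WORKERS"
    if [ "$DEPTH" -gt "$THRESHOLD" ] && [ "$WORKERS" -lt "$MAX_WORKERS" ]; then
        echo "Scaling up -- starting additional worker"
        # Separate pid and log files keep detached workers from colliding.
        CUDA_VISIBLE_DEVICES=$WORKERS celery -A celery_app worker --pool=solo -Q "$QUEUE" \
            -n "autoscaled-${WORKERS}@%h" --detach \
            --pidfile="/tmp/autoscaled-${WORKERS}.pid" --logfile="/tmp/autoscaled-${WORKERS}.log"
        # Let the new worker load its model and register before re-checking
        # (60 s is an assumption; match it to your model's load time).
        sleep 60
    fi
    sleep 10
done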
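The scaling rule in that loop is easier to unit-test as a pure function, and it benefits from a cooldown: a freshly started worker spends time loading its model before inspect reports it, so a loop without one can launch several workers for the same backlog. A minimal sketch; ScaleUpPolicy is a hypothetical class, and its threshold, cap, and 60-second cooldown are assumptions to adjust.

```python
import time
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class ScaleUpPolicy:
    """Same rule as the shell loop, plus a cooldown between scale-ups."""

    threshold: int = 50        # queue depth that triggers a scale-up
    max_workers: int = 4       # e.g. the number of GPUs available
    cooldown_s: float = 60.0   # assumed time for a new worker to load its model
    _last_scale_up: float = field(default=float("-inf"), repr=False)

    def should_scale_up(self, depth: int, workers: int, now: Optional[float] = None) -> bool:
        now = time.monotonic() if now is None else now
        if depth <= self.threshold or workers >= self.max_workers:
            return False
        if now - self._last_scale_up < self.cooldown_s:
            # A worker started recently may not have registered with inspect yet.
            return False
        self._last_scale_up = now
        return True
```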

Common Errors and Fixes

1. kombu.exceptions.OperationalError: Error 111 connecting to localhost:6379. Connection refused.

Redis is not running. Start it:

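docker run -d --name redis-inference -p 6379:6379 redis:7-alpine
# or, if the container already exists but is stopped:
docker start redis-inference
# or if using system Redis (on Debian/Ubuntu the unit is usually redis-server):
sudo systemctl start redis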

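Also check that broker_url in your Celery config points to the right host and port. Inside a Docker container, localhost means the container itself, not the host machine. If Redis runs in another container on the same Docker network (for example, a Compose service), use its service or container name as the host. If Redis runs on the host, use host.docker.internal, which Docker Desktop provides by default; on Linux, start the container with --add-host=host.docker.internal:host-gateway.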

2. RuntimeError: Cannot re-initialize CUDA in forked subprocess.

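This happens when CUDA is initialized in the parent worker process (for example, by loading the model or creating a CUDA tensor at import time) and Celery's default prefork pool then forks child processes, which cannot reuse the parent's CUDA context. The fix is to use --pool=solo or --pool=threads instead of the default prefork pool. If you choose threads, confirm the model actually loads: the worker_process_init hook used in this article is sent by the prefork and solo pools but may not be sent by the threads pool.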

# Wrong -- prefork forks child processes; breaks if CUDA was initialized before the fork
celery -A celery_app worker --concurrency=4

# Correct -- solo pool, one task at a time per process, no forking
celery -A celery_app worker --pool=solo --concurrency=1

If you need multiple GPU workers, start separate processes rather than relying on fork-based concurrency.

3. celery.exceptions.TimeLimitExceeded on large batch inputs

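Celery sets no time limit by default, so this error means one was configured somewhere: globally with task_time_limit, per worker with the --time-limit option, or per task. Give batch tasks explicit limits sized for their workload. One caveat: the Celery workers guide lists time-limit support for the prefork and gevent pools, so on the --pool=solo GPU workers used in this article the limits may not be enforced; verify before relying on them.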

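from celery.exceptions import SoftTimeLimitExceeded


@app.task(bind=True, time_limit=300, soft_time_limit=270)
def batch_predict(self, input_batch: list[list[float]]) -> list[dict]:
    try:
        ...  # inference code as before
    except SoftTimeLimitExceeded:
        ...  # save or return partial results before the hard limit kills the task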

The soft_time_limit raises SoftTimeLimitExceeded inside your task so you can save partial results. The hard time_limit kills the process running the task if it is still going when that limit expires, for example because it was stuck in native code and never reached the exception handler. Set both based on your maximum expected batch size: profile your model to find how long the largest realistic batch takes, then add a buffer.
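Turning "profile and add a buffer" into numbers: time the slowest realistic batch a few times, multiply the worst run by a safety factor for the soft limit, and give the hard limit a fixed grace period on top so the except-handler has time to run. A minimal sketch; the 1.5 factor and 30-second grace are assumptions rather than Celery defaults, and run_batch is a hypothetical callable (for example, one that runs the model on your largest batch).

```python
import math
import time
from typing import Callable


def profile_seconds(run_batch: Callable[[], object], repeats: int = 3) -> float:
    """Return the slowest of `repeats` timed runs of run_batch()."""
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        run_batch()
        timings.append(time.perf_counter() - start)
    return max(timings)


def time_limits(worst_case_s: float, safety: float = 1.5, grace_s: int = 30) -> tuple[int, int]:
    """Return (soft_time_limit, time_limit) in whole seconds.

    safety and grace_s are assumed defaults for this sketch.
    """
    soft = math.ceil(worst_case_s * safety)
    return soft, soft + grace_s
```

A 180-second worst case gives time_limits(180) == (270, 300), the pair used in the decorator above.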

4. redis.exceptions.ResponseError: OOM command not allowed when used memory > 'maxmemory'

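Redis ran out of memory, usually from too many stored results. Lower the result_expires setting, or set a maxmemory limit with an eviction policy. The broker queues (db 0) and the results (db 1) live in the same Redis instance and maxmemory applies to the whole instance, so avoid allkeys-* policies: they can evict queued tasks. A volatile-* policy evicts only keys that have an expiry, which covers results stored with result_expires but not the queue lists. Running the broker and the result backend on separate Redis instances avoids the conflict entirely.

redis-cli CONFIG SET maxmemory 2gb
redis-cli CONFIG SET maxmemory-policy volatile-lru
# CONFIG SET is lost on restart; also put these in redis.conf or pass them to redis-server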
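To pick result_expires or maxmemory deliberately, a rough estimate of steady-state result storage helps: at a steady completion rate, about rate × result_expires results are alive at any moment, each taking roughly a fixed number of bytes. A minimal sketch; the per-result size is an assumption you should measure, for example with redis-cli MEMORY USAGE on one celery-task-meta-* key.

```python
def estimated_result_bytes(tasks_per_second: float, result_expires_s: int, bytes_per_result: int) -> int:
    """Approximate steady-state memory used by stored Celery results.

    At a steady rate, about tasks_per_second * result_expires_s results are
    alive at any moment; bytes_per_result is a measured or assumed average.
    """
    live_results = tasks_per_second * result_expires_s
    return int(live_results * bytes_per_result)
```

At 100 predictions per second, the one-hour result_expires from the config above, and an assumed 500 bytes per result, that is about 180 MB; a one-day expiry would need 24 times as much.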

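Also consider whether you actually need the result backend. If your API polls for results, as the /result endpoint above does, you need it. If you only do fire-and-forget batch processing, you can skip storing results, either per task with @app.task(ignore_result=True) or for every task by disabling the backend to save memory: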

app.conf.update(
    result_backend=None,  # no result storage
)