Training jobs pile up fast when your team shares a GPU cluster. Someone kicks off a 70B fine-tune, another person needs a quick LoRA run, and suddenly everyone is waiting. A job queue with priority scheduling fixes this. Redis gives you the sorted sets and atomic operations you need, and Python’s multiprocessing handles the worker pool side.
Here’s what we’re building: a priority queue backed by Redis, Pydantic job schemas, worker processes that check GPU availability before pulling jobs, and a dead letter queue for failed runs.
Setting Up the Redis Job Queue with Priority Levels
Redis sorted sets are perfect for priority queues. Each job gets a score — lower scores mean higher priority. You push jobs with ZADD and pop the highest-priority one with ZPOPMIN.
Install the dependencies first:
```bash
pip install redis pydantic
```
Define the job schema and queue class:
```python
import json
import time
import uuid
from enum import IntEnum
from typing import Optional

import redis
from pydantic import BaseModel, Field


class Priority(IntEnum):
    CRITICAL = 0
    HIGH = 1
    NORMAL = 2
    LOW = 3


class TrainingJob(BaseModel):
    job_id: str = Field(default_factory=lambda: uuid.uuid4().hex[:12])
    model_name: str
    dataset_path: str
    epochs: int = 3
    batch_size: int = 32
    gpu_count: int = 1
    priority: Priority = Priority.NORMAL
    max_retries: int = 3
    retry_count: int = 0
    created_at: float = Field(default_factory=time.time)


class TrainingQueue:
    def __init__(self, redis_url: str = "redis://localhost:6379/0"):
        self.r = redis.from_url(redis_url, decode_responses=True)
        self.queue_key = "training:queue"
        self.status_prefix = "training:status:"
        self.dlq_key = "training:dead_letter"

    def enqueue(self, job: TrainingJob) -> str:
        # Score combines priority and timestamp so same-priority jobs are FIFO
        score = job.priority * 1e12 + job.created_at
        job_data = job.model_dump_json()
        self.r.zadd(self.queue_key, {job_data: score})
        self.r.set(f"{self.status_prefix}{job.job_id}", "queued")
        return job.job_id

    def dequeue(self) -> Optional[TrainingJob]:
        result = self.r.zpopmin(self.queue_key, count=1)
        if not result:
            return None
        job_data, _score = result[0]
        job = TrainingJob.model_validate_json(job_data)
        self.r.set(f"{self.status_prefix}{job.job_id}", "running")
        return job

    def get_status(self, job_id: str) -> Optional[str]:
        return self.r.get(f"{self.status_prefix}{job_id}")

    def mark_complete(self, job_id: str):
        self.r.set(f"{self.status_prefix}{job_id}", "completed")

    def send_to_dlq(self, job: TrainingJob, error: str):
        entry = json.dumps({"job": job.model_dump(), "error": error, "failed_at": time.time()})
        self.r.rpush(self.dlq_key, entry)
        self.r.set(f"{self.status_prefix}{job.job_id}", "failed")

    def requeue_with_retry(self, job: TrainingJob, error: str) -> bool:
        if job.retry_count >= job.max_retries:
            self.send_to_dlq(job, error)
            return False
        job.retry_count += 1
        self.enqueue(job)
        return True

    def queue_length(self) -> int:
        return self.r.zcard(self.queue_key)
```
The score formula priority * 1e12 + timestamp ensures critical jobs always run before low-priority ones, and within the same priority level, older jobs go first.
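As a quick sanity check, the dominance of the priority term is easy to verify with a couple of sample scores (the `score` helper here is just an illustration of the formula, not part of the queue code):

```python
import time

def score(priority: int, created_at: float) -> float:
    # Same formula as TrainingQueue.enqueue: lower score pops first
    return priority * 1e12 + created_at

now = time.time()
high_but_new = score(1, now)           # HIGH job submitted just now
normal_but_old = score(2, now - 3600)  # NORMAL job submitted an hour ago

# The HIGH job still wins: the priority step (1e12) dwarfs any realistic
# timestamp difference (~1.7e9 seconds since the epoch).
assert high_but_new < normal_but_old
```

This holds as long as Unix timestamps stay below 1e12, which is true until roughly the year 33658.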
GPU Resource Tracking
Workers shouldn’t grab a job if there aren’t enough free GPUs. Use nvidia-smi to check what’s available, then only dequeue when resources match.
```python
import subprocess


def get_free_gpus() -> list[int]:
    """Return list of GPU IDs with no running processes."""
    try:
        result = subprocess.run(
            ["nvidia-smi", "--query-gpu=index,memory.used", "--format=csv,noheader,nounits"],
            capture_output=True, text=True, timeout=10,
        )
        free = []
        for line in result.stdout.strip().split("\n"):
            if not line.strip():
                continue
            idx, mem_used = line.split(",")
            # GPU is "free" if less than 100 MB used
            if int(mem_used.strip()) < 100:
                free.append(int(idx.strip()))
        return free
    except (subprocess.TimeoutExpired, FileNotFoundError):
        return []


def try_dequeue_for_gpus(queue: TrainingQueue) -> Optional[TrainingJob]:
    """Peek at the top job and only pop it if enough GPUs are free."""
    # Peek without removing
    top = queue.r.zrange(queue.queue_key, 0, 0)
    if not top:
        return None
    job = TrainingJob.model_validate_json(top[0])
    free_gpus = get_free_gpus()
    if len(free_gpus) >= job.gpu_count:
        return queue.dequeue()
    return None
```
This peek-then-pop pattern avoids pulling a multi-GPU job when only one card is free. The job stays in the sorted set until resources open up.
Building the Worker Pool
Each worker runs in its own process, polls the queue, and executes training. We use multiprocessing.Process instead of threads because training is CPU/GPU-bound work and you want true parallelism.
```python
import os
import signal
import traceback
from multiprocessing import Process, Event


def train_model(job: TrainingJob, gpu_ids: list[int]):
    """Simulate or run actual training. Replace with your training logic."""
    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(g) for g in gpu_ids[:job.gpu_count])
    print(f"[Worker {os.getpid()}] Training {job.model_name} on GPUs {gpu_ids[:job.gpu_count]}")
    print(f"  Dataset: {job.dataset_path}, Epochs: {job.epochs}, Batch: {job.batch_size}")
    # Replace this with your actual training call, e.g.:
    # subprocess.run(["torchrun", "--nproc_per_node", str(job.gpu_count), "train.py", ...])
    time.sleep(2)  # Placeholder for actual training duration


def worker_loop(worker_id: int, redis_url: str, shutdown_event):
    queue = TrainingQueue(redis_url=redis_url)
    print(f"[Worker {worker_id}] Started, PID={os.getpid()}")
    while not shutdown_event.is_set():
        job = try_dequeue_for_gpus(queue)
        if job is None:
            time.sleep(2)
            continue
        print(f"[Worker {worker_id}] Picked up job {job.job_id}: {job.model_name}")
        try:
            free_gpus = get_free_gpus()
            train_model(job, free_gpus)
            queue.mark_complete(job.job_id)
            print(f"[Worker {worker_id}] Completed job {job.job_id}")
        except Exception as e:
            error_msg = traceback.format_exc()
            print(f"[Worker {worker_id}] Job {job.job_id} failed: {e}")
            retried = queue.requeue_with_retry(job, error_msg)
            if retried:
                print(f"[Worker {worker_id}] Requeued job {job.job_id} (retry {job.retry_count})")
            else:
                print(f"[Worker {worker_id}] Job {job.job_id} sent to dead letter queue")


class WorkerPool:
    def __init__(self, num_workers: int = 2, redis_url: str = "redis://localhost:6379/0"):
        self.num_workers = num_workers
        self.redis_url = redis_url
        self.shutdown_event = Event()
        self.workers: list[Process] = []

    def start(self):
        for i in range(self.num_workers):
            p = Process(target=worker_loop, args=(i, self.redis_url, self.shutdown_event))
            p.daemon = True
            p.start()
            self.workers.append(p)
        print(f"Started {self.num_workers} workers")

    def stop(self, timeout: int = 10):
        self.shutdown_event.set()
        for p in self.workers:
            p.join(timeout=timeout)
            if p.is_alive():
                p.terminate()
        print("All workers stopped")
```
Run the full system:
```python
if __name__ == "__main__":
    queue = TrainingQueue()
    pool = WorkerPool(num_workers=2)

    # Submit some jobs
    jobs = [
        TrainingJob(model_name="bert-base", dataset_path="/data/sst2", priority=Priority.NORMAL),
        TrainingJob(model_name="llama-7b-lora", dataset_path="/data/alpaca", priority=Priority.HIGH, gpu_count=2),
        TrainingJob(model_name="vit-large", dataset_path="/data/imagenet-1k", priority=Priority.LOW, epochs=10),
    ]
    for job in jobs:
        job_id = queue.enqueue(job)
        print(f"Enqueued {job.model_name} with priority {job.priority.name} -> {job_id}")
    print(f"Queue length: {queue.queue_length()}")

    # Start workers
    pool.start()
    try:
        signal.signal(signal.SIGINT, lambda *_: pool.stop())
        while any(p.is_alive() for p in pool.workers):
            time.sleep(1)
    except KeyboardInterrupt:
        pool.stop()
```
The llama-7b-lora job runs first despite being submitted second because it has HIGH priority. The vit-large job waits in the back of the line.
Job Status and Dead Letter Queue Inspection
You’ll want visibility into what’s running, what failed, and why. Here are some helper functions for inspecting queue state:
```python
def inspect_queue(queue: TrainingQueue) -> list[dict]:
    """List all pending jobs with their scores."""
    items = queue.r.zrange(queue.queue_key, 0, -1, withscores=True)
    result = []
    for job_data, score in items:
        job = TrainingJob.model_validate_json(job_data)
        result.append({
            "job_id": job.job_id,
            "model": job.model_name,
            "priority": job.priority.name,
            "gpu_count": job.gpu_count,
            "score": score,
        })
    return result


def inspect_dlq(queue: TrainingQueue) -> list[dict]:
    """List all jobs in the dead letter queue."""
    entries = queue.r.lrange(queue.dlq_key, 0, -1)
    result = []
    for entry in entries:
        data = json.loads(entry)
        result.append({
            "job_id": data["job"]["job_id"],
            "model": data["job"]["model_name"],
            "error": data["error"][:200],
            "failed_at": time.ctime(data["failed_at"]),
        })
    return result


def retry_from_dlq(queue: TrainingQueue, job_id: str) -> bool:
    """Pull a specific job from the DLQ and re-enqueue it with reset retries."""
    entries = queue.r.lrange(queue.dlq_key, 0, -1)
    for entry in entries:
        data = json.loads(entry)
        if data["job"]["job_id"] == job_id:
            job = TrainingJob.model_validate(data["job"])
            job.retry_count = 0
            queue.enqueue(job)
            queue.r.lrem(queue.dlq_key, 1, entry)
            return True
    return False
```
Use inspect_queue to see what’s pending and inspect_dlq to review failures. The retry_from_dlq function lets you manually re-submit a failed job after fixing whatever broke it.
Common Errors and Fixes
redis.exceptions.ConnectionError: Error 111 connecting to localhost:6379
Redis isn’t running. Start it with redis-server or systemctl start redis. If you’re using Docker: docker run -d -p 6379:6379 redis:7.
ZPOPMIN returns empty even though jobs are queued
Check that your queue_key matches across your producer and consumer code. A common mistake is connecting to a different Redis database number (e.g., /0 vs /1 in the URL). Verify with redis-cli ZCARD training:queue.
Workers dequeue jobs but GPUs show as busy
The nvidia-smi memory threshold of 100 MB may be too low. Some drivers report baseline memory usage around 200-300 MB even with no processes. Bump the threshold in get_free_gpus or switch to checking gpu_util instead of memory.used.
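A sketch of the utilization-based variant, with the parsing split out so it can be checked without a GPU (parse_gpu_util and get_idle_gpus are illustrative names, not part of the queue code above):

```python
import subprocess

def parse_gpu_util(output: str, threshold_pct: int = 5) -> list[int]:
    """Return GPU indices whose reported utilization is at or below the threshold."""
    idle = []
    for line in output.strip().splitlines():
        if not line.strip():
            continue
        idx, util = (part.strip() for part in line.split(","))
        if int(util) <= threshold_pct:
            idle.append(int(idx))
    return idle

def get_idle_gpus(threshold_pct: int = 5) -> list[int]:
    # utilization.gpu is a standard nvidia-smi query field; empty list on failure
    try:
        result = subprocess.run(
            ["nvidia-smi", "--query-gpu=index,utilization.gpu",
             "--format=csv,noheader,nounits"],
            capture_output=True, text=True, timeout=10,
        )
    except (subprocess.TimeoutExpired, FileNotFoundError):
        return []
    return parse_gpu_util(result.stdout, threshold_pct)
```

A small nonzero threshold avoids flagging a GPU as busy because of transient background activity; tune it to your fleet.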
pydantic.ValidationError when deserializing jobs
This happens when you change the TrainingJob schema after jobs are already in the queue. Old serialized jobs won’t match the new schema. Either drain the queue before deploying schema changes, or add Optional defaults to new fields so old payloads still validate.
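For example, a new field with a default keeps old serialized payloads valid. A minimal sketch using a hypothetical TrainingJobV2 (not the schema above):

```python
from typing import Optional
from pydantic import BaseModel

class TrainingJobV2(BaseModel):
    model_name: str
    # Hypothetical field added after jobs were already serialized into Redis;
    # the default lets old payloads still validate instead of raising.
    warmup_steps: Optional[int] = None

# An "old" payload that predates warmup_steps still deserializes cleanly
old_payload = '{"model_name": "bert-base"}'
job = TrainingJobV2.model_validate_json(old_payload)
# job.warmup_steps is None; no ValidationError raised
```

Making the new field required, by contrast, would break every job enqueued before the deploy.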
Race condition: two workers grab the same job
ZPOPMIN is atomic in Redis, so two workers won’t pop the same entry. But the peek-then-pop pattern in try_dequeue_for_gpus has a small window where both workers peek, both see free GPUs, and both try to dequeue. The second dequeue call gets a different job (or None), so you won’t get duplicate execution. If you need stricter guarantees, wrap the peek-and-pop in a Redis Lua script or use WATCH/MULTI transactions.
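A sketch of the Lua approach, assuming the TrainingQueue from above (POP_IF_FITS and atomic_dequeue are illustrative names): the GPU-count check and the pop run as one server-side script, so no other worker can slip in between them.

```python
# Lua runs atomically on the Redis server; cjson is built into Redis scripting.
POP_IF_FITS = """
local top = redis.call('ZRANGE', KEYS[1], 0, 0)
if #top == 0 then return nil end
local job = cjson.decode(top[1])
if job.gpu_count <= tonumber(ARGV[1]) then
    redis.call('ZREM', KEYS[1], top[1])
    return top[1]
end
return nil
"""

def atomic_dequeue(queue, free_gpu_count: int):
    """Pop the top job only if it fits on the free GPUs.

    Returns the raw job JSON (validate with TrainingJob.model_validate_json)
    or None if the queue is empty or the job needs more GPUs than are free.
    """
    script = queue.r.register_script(POP_IF_FITS)
    return script(keys=[queue.queue_key], args=[free_gpu_count])
```

The free-GPU count still comes from nvidia-smi on the worker, so it can be slightly stale, but two workers can no longer pop based on the same snapshot of the queue.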
Dead letter queue grows unbounded
Set a TTL or max length on your DLQ. Add a periodic cleanup: queue.r.ltrim(queue.dlq_key, 0, 999) keeps only the latest 1000 failures. Or push DLQ entries with timestamps and purge anything older than 7 days.
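An age-based purge might look like this (purge_old_dlq_entries is an illustrative helper, assuming the DLQ entry format written by send_to_dlq above):

```python
import json
import time

SEVEN_DAYS = 7 * 24 * 3600

def purge_old_dlq_entries(queue, max_age: float = SEVEN_DAYS) -> int:
    """Remove DLQ entries whose failed_at timestamp is older than max_age seconds.

    Returns the number of entries removed.
    """
    cutoff = time.time() - max_age
    removed = 0
    for entry in queue.r.lrange(queue.dlq_key, 0, -1):
        if json.loads(entry)["failed_at"] < cutoff:
            # LREM with count=1 deletes this exact entry
            queue.r.lrem(queue.dlq_key, 1, entry)
            removed += 1
    return removed
```

Run it from a cron job or a periodic task in the worker pool, whichever you already have.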