When multiple people share a GPU box (or a small cluster), training jobs collide. Someone launches a 13B fine-tune that eats 80 GB of VRAM, and now your quick eval job sits there doing nothing. You need a scheduler — something that checks real GPU memory, queues jobs by priority, and launches them when resources free up.
We’ll build one with Redis for the priority queue, pynvml for GPU memory checks, subprocess management for launching training runs, and FastAPI for a REST API to submit and monitor jobs.
Install the dependencies:
1
| pip install redis pynvml fastapi uvicorn pydantic
|
The Priority Queue with Redis#
Redis sorted sets give you a natural priority queue. Lower scores pop first. We’ll encode priority level plus submission time into the score so that within the same priority, older jobs run first.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
| import json
import time
import uuid
from enum import IntEnum
from dataclasses import dataclass, field, asdict
import redis
class Priority(IntEnum):
CRITICAL = 0
HIGH = 100
NORMAL = 200
LOW = 300
@dataclass
class TrainingJob:
model_script: str
gpu_memory_mb: int
num_gpus: int = 1
priority: int = Priority.NORMAL
job_id: str = field(default_factory=lambda: uuid.uuid4().hex[:12])
status: str = "queued"
assigned_gpus: list[int] = field(default_factory=list)
created_at: float = field(default_factory=time.time)
pid: int | None = None
class JobQueue:
    """Priority queue for training jobs, backed by a Redis sorted set.

    The sorted set ``scheduler:queue`` holds job ids ordered by score;
    job metadata lives as JSON under ``scheduler:job:<id>``.
    """

    def __init__(self, redis_url: str = "redis://localhost:6379/0"):
        # decode_responses=True so Redis returns str, not bytes.
        self.r = redis.from_url(redis_url, decode_responses=True)
        self.queue_key = "scheduler:queue"
        self.job_prefix = "scheduler:job:"

    def submit(self, job: TrainingJob) -> str:
        """Enqueue *job* and persist its metadata. Returns the job id."""
        # Score = priority * 1e10 + timestamp, so same-priority jobs are FIFO
        score = job.priority * 1e10 + job.created_at
        self.r.zadd(self.queue_key, {job.job_id: score})
        self.r.set(self.job_prefix + job.job_id, json.dumps(asdict(job)))
        return job.job_id

    def pop_next(self) -> TrainingJob | None:
        """Remove and return the lowest-score job, or None if the queue is empty."""
        result = self.r.zpopmin(self.queue_key, count=1)
        if not result:
            return None
        job_id, _score = result[0]
        raw = self.r.get(self.job_prefix + job_id)
        if not raw:
            # Metadata was deleted out from under us; treat as no job.
            return None
        return TrainingJob(**json.loads(raw))

    def peek(self, count: int = 10) -> list[TrainingJob]:
        """Return up to *count* queued jobs in pop order, without removing them."""
        entries = self.r.zrange(self.queue_key, 0, count - 1)
        if not entries:
            return []
        # One MGET round trip instead of one GET per job id.
        raws = self.r.mget([self.job_prefix + job_id for job_id in entries])
        return [TrainingJob(**json.loads(raw)) for raw in raws if raw]

    def update_job(self, job: TrainingJob) -> None:
        """Overwrite the stored metadata for *job* (status, PID, GPUs, ...)."""
        self.r.set(self.job_prefix + job.job_id, json.dumps(asdict(job)))

    def queue_length(self) -> int:
        """Number of jobs currently waiting in the queue."""
        return self.r.zcard(self.queue_key)
|
The score trick is the key detail. priority * 1e10 + timestamp means a CRITICAL job submitted 10 minutes ago still beats a NORMAL job submitted a week ago, but two NORMAL jobs respect submission order.
GPU Availability with pynvml#
Before launching a job, you need to know which GPUs have enough free memory. pynvml talks directly to the NVIDIA driver — no shelling out to nvidia-smi.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import pynvml


def get_gpu_free_memory() -> list[dict]:
    """Return free/total/used memory in MB for each GPU visible to NVML.

    NVML is initialized per call; the shutdown sits in a ``finally`` so a
    failing device query cannot leak the NVML session (the original
    skipped nvmlShutdown() whenever a query raised).
    """
    pynvml.nvmlInit()
    try:
        gpu_count = pynvml.nvmlDeviceGetCount()
        gpus = []
        for i in range(gpu_count):
            handle = pynvml.nvmlDeviceGetHandleByIndex(i)
            mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
            gpus.append({
                "index": i,
                "free_mb": mem_info.free // (1024 * 1024),
                "total_mb": mem_info.total // (1024 * 1024),
                "used_mb": mem_info.used // (1024 * 1024),
            })
        return gpus
    finally:
        pynvml.nvmlShutdown()
def find_available_gpus(required_mb: int, num_gpus: int) -> list[int] | None:
    """Pick *num_gpus* GPU indices with at least *required_mb* free, or None."""
    # Rank emptiest-first so new jobs land on the least-loaded devices.
    ranked = sorted(get_gpu_free_memory(), key=lambda g: g["free_mb"], reverse=True)
    eligible = [g["index"] for g in ranked if g["free_mb"] >= required_mb]
    return eligible[:num_gpus] if len(eligible) >= num_gpus else None
|
This checks actual free memory, not just whether a GPU is “in use.” A GPU running a small inference server with 4 GB used on a 24 GB card still has 20 GB free for a training job. The function sorts by free memory so jobs land on the least-loaded GPUs first.
The Scheduler Loop#
The scheduler runs as a background process. It polls the queue, checks GPU availability, and launches jobs as subprocesses. It also tracks running jobs and cleans up when they finish.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
| import subprocess
import os
import signal
import time
class TrainingScheduler:
    """Polls the job queue, launches training scripts, and reaps them.

    Log file handles opened for each subprocess are tracked in
    ``self.log_files`` and closed when the job finishes or is preempted —
    the original leaked one descriptor per launched job.
    """

    def __init__(self, queue: JobQueue, poll_interval: float = 5.0):
        self.queue = queue
        self.poll_interval = poll_interval
        self.running_jobs: dict[str, subprocess.Popen] = {}
        # job_id -> open log file handle; closed when the job ends.
        self.log_files: dict[str, object] = {}

    def launch_job(self, job: TrainingJob, gpu_ids: list[int]) -> subprocess.Popen:
        """Start *job* as a subprocess pinned to *gpu_ids* via CUDA_VISIBLE_DEVICES."""
        env = os.environ.copy()
        env["CUDA_VISIBLE_DEVICES"] = ",".join(str(g) for g in gpu_ids)
        log = open(f"/tmp/train_{job.job_id}.log", "w")
        try:
            proc = subprocess.Popen(
                ["python", job.model_script],
                env=env,
                stdout=log,
                stderr=subprocess.STDOUT,
            )
        except Exception:
            # Popen failed (e.g. missing interpreter) — don't leak the handle.
            log.close()
            raise
        self.log_files[job.job_id] = log
        return proc

    def _close_log(self, job_id: str) -> None:
        """Close and forget the log handle for *job_id*, if one is open."""
        log = self.log_files.pop(job_id, None)
        if log is not None:
            log.close()

    def check_running_jobs(self) -> None:
        """Reap finished subprocesses: update status in Redis, close logs."""
        finished = []
        for job_id, proc in self.running_jobs.items():
            retcode = proc.poll()
            if retcode is not None:
                status = "completed" if retcode == 0 else "failed"
                raw = self.queue.r.get(self.queue.job_prefix + job_id)
                if raw:
                    job = TrainingJob(**json.loads(raw))
                    job.status = status
                    self.queue.update_job(job)
                finished.append(job_id)
                print(f"Job {job_id} {status} (exit code {retcode})")
        # Deleting after the loop avoids mutating the dict while iterating.
        for job_id in finished:
            del self.running_jobs[job_id]
            self._close_log(job_id)

    def preempt_for_critical(self, critical_job: TrainingJob) -> bool:
        """Kill the lowest-priority running job to free GPUs for a critical job."""
        if not self.running_jobs:
            return False
        # Find the lowest-priority running job (largest numeric priority).
        worst_job_id = None
        worst_priority = -1
        for job_id in self.running_jobs:
            raw = self.queue.r.get(self.queue.job_prefix + job_id)
            if raw:
                running = TrainingJob(**json.loads(raw))
                if running.priority > worst_priority:
                    worst_priority = running.priority
                    worst_job_id = job_id
        # Only preempt if the running job has strictly lower priority.
        if worst_job_id and worst_priority > critical_job.priority:
            proc = self.running_jobs[worst_job_id]
            proc.send_signal(signal.SIGTERM)
            try:
                proc.wait(timeout=30)
            except subprocess.TimeoutExpired:
                # Script ignored SIGTERM; escalate so the GPUs actually free up.
                proc.kill()
                proc.wait()
            raw = self.queue.r.get(self.queue.job_prefix + worst_job_id)
            if raw:
                preempted = TrainingJob(**json.loads(raw))
                preempted.status = "queued"
                preempted.pid = None
                self.queue.submit(preempted)  # Re-queue it
            del self.running_jobs[worst_job_id]
            self._close_log(worst_job_id)
            print(f"Preempted job {worst_job_id} for critical job {critical_job.job_id}")
            return True
        return False

    def run(self) -> None:
        """Main loop: reap finished jobs, then try to launch the next queued one."""
        print("Scheduler started. Polling queue...")
        while True:
            self.check_running_jobs()
            job = self.queue.pop_next()
            if job is None:
                time.sleep(self.poll_interval)
                continue
            gpu_ids = find_available_gpus(job.gpu_memory_mb, job.num_gpus)
            if gpu_ids is None and job.priority == Priority.CRITICAL:
                # Only CRITICAL jobs may evict a running lower-priority job.
                if self.preempt_for_critical(job):
                    gpu_ids = find_available_gpus(job.gpu_memory_mb, job.num_gpus)
            if gpu_ids is None:
                # No GPUs free — re-queue and wait
                self.queue.submit(job)
                time.sleep(self.poll_interval)
                continue
            job.status = "running"
            job.assigned_gpus = gpu_ids
            proc = self.launch_job(job, gpu_ids)
            job.pid = proc.pid
            self.queue.update_job(job)
            self.running_jobs[job.job_id] = proc
            print(f"Launched job {job.job_id} on GPUs {gpu_ids} (PID {proc.pid})")
|
The preemption logic only kicks in for CRITICAL jobs, and only kills a lower-priority job. The preempted job goes back into the queue at its original priority so it gets picked up again once GPUs free up.
FastAPI REST API for Job Submission#
You want a REST endpoint so teammates can submit jobs from notebooks or scripts without touching Redis directly. FastAPI’s lifespan context manager handles starting and stopping the scheduler cleanly.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
| import threading
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
# Global references set during lifespan.
# Both stay None until the FastAPI lifespan hook runs, so endpoints must
# only be served after startup (FastAPI guarantees this ordering).
scheduler: TrainingScheduler | None = None
queue: JobQueue | None = None
class JobRequest(BaseModel):
    """Request body for POST /jobs — everything needed to schedule a run."""

    model_script: str       # path to the training script, run as `python <script>`
    gpu_memory_mb: int      # free VRAM required per GPU, in MB
    num_gpus: int = 1
    priority: int = Priority.NORMAL  # lower value = higher priority (see Priority)
class JobResponse(BaseModel):
    """Job summary returned by the /jobs endpoints."""

    job_id: str
    status: str             # queued | running | completed | failed
    assigned_gpus: list[int]  # empty until the scheduler launches the job
    priority: int
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create the queue and scheduler, run the scheduler for the app's lifetime."""
    global scheduler, queue
    queue = JobQueue()
    scheduler = TrainingScheduler(queue)
    # Run scheduler loop in a background thread
    thread = threading.Thread(target=scheduler.run, daemon=True)
    thread.start()
    yield
    # Cleanup: scheduler thread dies with the process since it's a daemon
    # NOTE(review): a daemon-thread death does not stop already-launched
    # training subprocesses — they keep running as orphans. Confirm that is
    # acceptable, or add explicit termination here.
app = FastAPI(title="Training Scheduler", lifespan=lifespan)


@app.post("/jobs", response_model=JobResponse)
def submit_job(req: JobRequest):
    """Build a TrainingJob from the request, enqueue it, and echo its summary."""
    new_job = TrainingJob(
        model_script=req.model_script,
        gpu_memory_mb=req.gpu_memory_mb,
        num_gpus=req.num_gpus,
        priority=req.priority,
    )
    queue.submit(new_job)
    return JobResponse(
        job_id=new_job.job_id,
        status=new_job.status,
        assigned_gpus=new_job.assigned_gpus,
        priority=new_job.priority,
    )
@app.get("/jobs/{job_id}", response_model=JobResponse)
def get_job(job_id: str):
    """Fetch a job's stored metadata from Redis; 404 if the id is unknown."""
    raw = queue.r.get(queue.job_prefix + job_id)
    if not raw:
        raise HTTPException(status_code=404, detail="Job not found")
    stored = TrainingJob(**json.loads(raw))
    return JobResponse(
        job_id=stored.job_id,
        status=stored.status,
        assigned_gpus=stored.assigned_gpus,
        priority=stored.priority,
    )
@app.get("/queue")
def list_queue():
    """Snapshot of queue depth, running-job count, and the next 50 queued jobs."""
    summaries = [
        {"job_id": j.job_id, "priority": j.priority, "status": j.status}
        for j in queue.peek(count=50)
    ]
    return {
        "queue_length": queue.queue_length(),
        "running": len(scheduler.running_jobs),
        "jobs": summaries,
    }
@app.get("/gpus")
def gpu_status():
    """Report per-GPU free/total/used memory (MB) straight from NVML."""
    return get_gpu_free_memory()
|
Start it with:
1
| uvicorn scheduler_api:app --host 0.0.0.0 --port 8000
|
Then submit a job from anywhere:
1
2
3
| curl -X POST http://localhost:8000/jobs \
-H "Content-Type: application/json" \
-d '{"model_script": "train_lora.py", "gpu_memory_mb": 16000, "num_gpus": 1, "priority": 100}'
|
Check GPU status and queue depth:
1
2
| curl http://localhost:8000/gpus
curl http://localhost:8000/queue
|
Common Errors and Fixes#
pynvml fails with NVMLError_LibraryNotFound
The NVIDIA driver isn’t installed or libnvidia-ml.so isn’t on the library path. Verify with:
1
| ldconfig -p | grep libnvidia-ml
|
If it’s missing, install the NVIDIA driver (not just the CUDA toolkit). On Ubuntu: apt install nvidia-driver-550.
Redis connection refused
Make sure Redis is running. The fastest way to get it up:
1
| docker run -d --name redis -p 6379:6379 redis:7-alpine
|
Jobs stuck in “queued” forever
This usually means gpu_memory_mb in the job request is higher than any GPU’s actual free memory. Check with the /gpus endpoint. Also note that pynvml (NVML) ignores CUDA_VISIBLE_DEVICES and always enumerates every physical GPU — so if that variable is set in the scheduler’s environment, the NVML indices the scheduler assigns may not match the CUDA device ordering the training process sees. Unset it in the scheduler’s environment to keep the two index spaces aligned.
Preempted job won’t restart
If the training script doesn’t handle SIGTERM gracefully, the process might leave GPU memory allocated until the driver cleans it up. Add a signal handler in your training scripts:
1
2
3
4
5
6
7
8
9
| import signal
import sys
def handle_sigterm(_signum, _frame):
    """SIGTERM handler: checkpoint (user-supplied) and exit with status 0."""
    print("Received SIGTERM, saving checkpoint and exiting...")
    # Save your checkpoint here
    sys.exit(0)


signal.signal(signal.SIGTERM, handle_sigterm)
|
File descriptor leak from subprocess logs
The open() call for log files in launch_job creates file handles that stick around. In production, wrap the launch in a context manager or store the file objects and close them in check_running_jobs when the process finishes.