When multiple people share a GPU box (or a small cluster), training jobs collide. Someone launches a 13B fine-tune that eats 80 GB of VRAM, and now your quick eval job sits there doing nothing. You need a scheduler — something that checks real GPU memory, queues jobs by priority, and launches them when resources free up.

We’ll build one with Redis for the priority queue, pynvml for GPU memory checks, subprocess management for launching training runs, and FastAPI for a REST API to submit and monitor jobs.

Install the dependencies:

pip install redis pynvml fastapi uvicorn pydantic

The Priority Queue with Redis

Redis sorted sets give you a natural priority queue. Lower scores pop first. We’ll encode priority level plus submission time into the score so that within the same priority, older jobs run first.

import json
import time
import uuid
from enum import IntEnum
from dataclasses import dataclass, field, asdict

import redis

class Priority(IntEnum):
    CRITICAL = 0
    HIGH = 100
    NORMAL = 200
    LOW = 300

@dataclass
class TrainingJob:
    model_script: str
    gpu_memory_mb: int
    num_gpus: int = 1
    priority: int = Priority.NORMAL
    job_id: str = field(default_factory=lambda: uuid.uuid4().hex[:12])
    status: str = "queued"
    assigned_gpus: list[int] = field(default_factory=list)
    created_at: float = field(default_factory=time.time)
    pid: int | None = None

class JobQueue:
    def __init__(self, redis_url: str = "redis://localhost:6379/0"):
        self.r = redis.from_url(redis_url, decode_responses=True)
        self.queue_key = "scheduler:queue"
        self.job_prefix = "scheduler:job:"

    def submit(self, job: TrainingJob) -> str:
        # Score = priority * 1e10 + timestamp, so same-priority jobs are FIFO
        score = job.priority * 1e10 + job.created_at
        self.r.zadd(self.queue_key, {job.job_id: score})
        self.r.set(self.job_prefix + job.job_id, json.dumps(asdict(job)))
        return job.job_id

    def pop_next(self) -> TrainingJob | None:
        result = self.r.zpopmin(self.queue_key, count=1)
        if not result:
            return None
        job_id, _score = result[0]
        raw = self.r.get(self.job_prefix + job_id)
        if not raw:
            return None
        return TrainingJob(**json.loads(raw))

    def peek(self, count: int = 10) -> list[TrainingJob]:
        entries = self.r.zrange(self.queue_key, 0, count - 1)
        jobs = []
        for job_id in entries:
            raw = self.r.get(self.job_prefix + job_id)
            if raw:
                jobs.append(TrainingJob(**json.loads(raw)))
        return jobs

    def update_job(self, job: TrainingJob) -> None:
        self.r.set(self.job_prefix + job.job_id, json.dumps(asdict(job)))

    def queue_length(self) -> int:
        return self.r.zcard(self.queue_key)

The score trick is the key detail. priority * 1e10 + timestamp means a CRITICAL job submitted 10 minutes ago still beats a NORMAL job submitted a week ago, but two NORMAL jobs respect submission order.
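To make the ordering concrete, here's a quick sanity check of the score arithmetic (the timestamps are made up). The 1e10 multiplier works because Unix timestamps sit around 1.7e9, well below the width of one priority band:

```python
from enum import IntEnum

class Priority(IntEnum):
    CRITICAL = 0
    HIGH = 100
    NORMAL = 200
    LOW = 300

def score(priority: int, created_at: float) -> float:
    # Same formula as JobQueue.submit
    return priority * 1e10 + created_at

week_old_normal = score(Priority.NORMAL, 1_700_000_000)
fresh_critical = score(Priority.CRITICAL, 1_700_600_000)
fresh_normal = score(Priority.NORMAL, 1_700_600_000)

assert fresh_critical < week_old_normal   # CRITICAL always pops first
assert week_old_normal < fresh_normal     # same priority: FIFO by timestamp
```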

GPU Availability with pynvml

Before launching a job, you need to know which GPUs have enough free memory. pynvml talks directly to the NVIDIA driver — no shelling out to nvidia-smi.

import pynvml

def get_gpu_free_memory() -> list[dict]:
    """Return free memory in MB for each GPU."""
    pynvml.nvmlInit()
    gpu_count = pynvml.nvmlDeviceGetCount()
    gpus = []
    for i in range(gpu_count):
        handle = pynvml.nvmlDeviceGetHandleByIndex(i)
        mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
        gpus.append({
            "index": i,
            "free_mb": mem_info.free // (1024 * 1024),
            "total_mb": mem_info.total // (1024 * 1024),
            "used_mb": mem_info.used // (1024 * 1024),
        })
    pynvml.nvmlShutdown()
    return gpus

def find_available_gpus(required_mb: int, num_gpus: int) -> list[int] | None:
    """Find GPUs with enough free memory. Returns GPU indices or None."""
    gpus = get_gpu_free_memory()
    # Sort by free memory descending — prefer emptiest GPUs
    gpus.sort(key=lambda g: g["free_mb"], reverse=True)
    candidates = [g["index"] for g in gpus if g["free_mb"] >= required_mb]
    if len(candidates) >= num_gpus:
        return candidates[:num_gpus]
    return None

This checks actual free memory, not just whether a GPU is “in use.” A GPU running a small inference server with 4 GB used on a 24 GB card still has 20 GB free for a training job. The function sorts by free memory so jobs land on the least-loaded GPUs first.

The Scheduler Loop

The scheduler runs as a background process. It polls the queue, checks GPU availability, and launches jobs as subprocesses. It also tracks running jobs and cleans up when they finish.

import json
import os
import signal
import subprocess
import time

class TrainingScheduler:
    def __init__(self, queue: JobQueue, poll_interval: float = 5.0):
        self.queue = queue
        self.poll_interval = poll_interval
        self.running_jobs: dict[str, subprocess.Popen] = {}

    def launch_job(self, job: TrainingJob, gpu_ids: list[int]) -> subprocess.Popen:
        env = os.environ.copy()
        env["CUDA_VISIBLE_DEVICES"] = ",".join(str(g) for g in gpu_ids)
        proc = subprocess.Popen(
            ["python", job.model_script],
            env=env,
            stdout=open(f"/tmp/train_{job.job_id}.log", "w"),
            stderr=subprocess.STDOUT,
        )
        return proc

    def check_running_jobs(self) -> None:
        finished = []
        for job_id, proc in self.running_jobs.items():
            retcode = proc.poll()
            if retcode is not None:
                status = "completed" if retcode == 0 else "failed"
                raw = self.queue.r.get(self.queue.job_prefix + job_id)
                if raw:
                    job = TrainingJob(**json.loads(raw))
                    job.status = status
                    self.queue.update_job(job)
                finished.append(job_id)
                print(f"Job {job_id} {status} (exit code {retcode})")
        for job_id in finished:
            del self.running_jobs[job_id]

    def preempt_for_critical(self, critical_job: TrainingJob) -> bool:
        """Kill the lowest-priority running job to free GPUs for a critical job."""
        if not self.running_jobs:
            return False

        # Find the lowest-priority running job
        worst_job_id = None
        worst_priority = -1
        for job_id in self.running_jobs:
            raw = self.queue.r.get(self.queue.job_prefix + job_id)
            if raw:
                running = TrainingJob(**json.loads(raw))
                if running.priority > worst_priority:
                    worst_priority = running.priority
                    worst_job_id = job_id

        # Only preempt if the running job has lower priority
        if worst_job_id and worst_priority > critical_job.priority:
            proc = self.running_jobs[worst_job_id]
            proc.send_signal(signal.SIGTERM)
            try:
                proc.wait(timeout=30)
            except subprocess.TimeoutExpired:
                proc.kill()  # escalate if the script ignores SIGTERM
                proc.wait()
            raw = self.queue.r.get(self.queue.job_prefix + worst_job_id)
            if raw:
                preempted = TrainingJob(**json.loads(raw))
                preempted.status = "queued"
                preempted.pid = None
                self.queue.submit(preempted)  # Re-queue it
            del self.running_jobs[worst_job_id]
            print(f"Preempted job {worst_job_id} for critical job {critical_job.job_id}")
            return True
        return False

    def run(self) -> None:
        print("Scheduler started. Polling queue...")
        while True:
            self.check_running_jobs()

            job = self.queue.pop_next()
            if job is None:
                time.sleep(self.poll_interval)
                continue

            gpu_ids = find_available_gpus(job.gpu_memory_mb, job.num_gpus)

            if gpu_ids is None and job.priority == Priority.CRITICAL:
                if self.preempt_for_critical(job):
                    gpu_ids = find_available_gpus(job.gpu_memory_mb, job.num_gpus)

            if gpu_ids is None:
                # No GPUs free — re-queue and wait
                self.queue.submit(job)
                time.sleep(self.poll_interval)
                continue

            job.status = "running"
            job.assigned_gpus = gpu_ids
            proc = self.launch_job(job, gpu_ids)
            job.pid = proc.pid
            self.queue.update_job(job)
            self.running_jobs[job.job_id] = proc
            print(f"Launched job {job.job_id} on GPUs {gpu_ids} (PID {proc.pid})")

The preemption logic only kicks in for CRITICAL jobs, and only kills a lower-priority job. The preempted job goes back into the queue at its original priority so it gets picked up again once GPUs free up.

FastAPI REST API for Job Submission

You want a REST endpoint so teammates can submit jobs from notebooks or scripts without touching Redis directly. FastAPI’s lifespan context manager handles starting and stopping the scheduler cleanly.

import json
import threading
from contextlib import asynccontextmanager

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

# Global references set during lifespan
scheduler: TrainingScheduler | None = None
queue: JobQueue | None = None

class JobRequest(BaseModel):
    model_script: str
    gpu_memory_mb: int
    num_gpus: int = 1
    priority: int = Priority.NORMAL

class JobResponse(BaseModel):
    job_id: str
    status: str
    assigned_gpus: list[int]
    priority: int

@asynccontextmanager
async def lifespan(app: FastAPI):
    global scheduler, queue
    queue = JobQueue()
    scheduler = TrainingScheduler(queue)
    # Run scheduler loop in a background thread
    thread = threading.Thread(target=scheduler.run, daemon=True)
    thread.start()
    yield
    # Cleanup: scheduler thread dies with the process since it's a daemon

app = FastAPI(title="Training Scheduler", lifespan=lifespan)

@app.post("/jobs", response_model=JobResponse)
def submit_job(req: JobRequest):
    job = TrainingJob(
        model_script=req.model_script,
        gpu_memory_mb=req.gpu_memory_mb,
        num_gpus=req.num_gpus,
        priority=req.priority,
    )
    queue.submit(job)
    return JobResponse(
        job_id=job.job_id,
        status=job.status,
        assigned_gpus=job.assigned_gpus,
        priority=job.priority,
    )

@app.get("/jobs/{job_id}", response_model=JobResponse)
def get_job(job_id: str):
    raw = queue.r.get(queue.job_prefix + job_id)
    if not raw:
        raise HTTPException(status_code=404, detail="Job not found")
    job = TrainingJob(**json.loads(raw))
    return JobResponse(
        job_id=job.job_id,
        status=job.status,
        assigned_gpus=job.assigned_gpus,
        priority=job.priority,
    )

@app.get("/queue")
def list_queue():
    jobs = queue.peek(count=50)
    return {
        "queue_length": queue.queue_length(),
        "running": len(scheduler.running_jobs),
        "jobs": [
            {"job_id": j.job_id, "priority": j.priority, "status": j.status}
            for j in jobs
        ],
    }

@app.get("/gpus")
def gpu_status():
    return get_gpu_free_memory()

Start it with:

uvicorn scheduler_api:app --host 0.0.0.0 --port 8000

Then submit a job from anywhere:

curl -X POST http://localhost:8000/jobs \
  -H "Content-Type: application/json" \
  -d '{"model_script": "train_lora.py", "gpu_memory_mb": 16000, "num_gpus": 1, "priority": 100}'

Check GPU status and queue depth:

curl http://localhost:8000/gpus
curl http://localhost:8000/queue
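If you'd rather not shell out to curl, a stdlib-only Python client is a few lines. This is a sketch that assumes the server above is reachable at localhost:8000; build_payload is a helper name invented here:

```python
import json
import urllib.request

BASE = "http://localhost:8000"

def build_payload(script: str, mem_mb: int, num_gpus: int = 1,
                  priority: int = 200) -> bytes:
    # Matches the JobRequest model fields
    return json.dumps({
        "model_script": script,
        "gpu_memory_mb": mem_mb,
        "num_gpus": num_gpus,
        "priority": priority,
    }).encode()

def submit(script: str, mem_mb: int, **kwargs) -> str:
    req = urllib.request.Request(
        f"{BASE}/jobs",
        data=build_payload(script, mem_mb, **kwargs),
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(req) as resp:
        return json.loads(resp.read())["job_id"]

def job_status(job_id: str) -> str:
    with urllib.request.urlopen(f"{BASE}/jobs/{job_id}") as resp:
        return json.loads(resp.read())["status"]
```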

Common Errors and Fixes

pynvml fails with NVMLError_LibraryNotFound

The NVIDIA driver isn’t installed or libnvidia-ml.so isn’t on the library path. Verify with:

ldconfig -p | grep libnvidia-ml

If it’s missing, install the NVIDIA driver (not just the CUDA toolkit). On Ubuntu: apt install nvidia-driver-550.

Redis connection refused

Make sure Redis is running. The fastest way to get it up:

docker run -d --name redis -p 6379:6379 redis:7-alpine

Jobs stuck in “queued” forever

This usually means gpu_memory_mb in the job request is higher than any GPU’s actual free memory. Check with the /gpus endpoint. Also check that CUDA_VISIBLE_DEVICES isn’t already set in the scheduler’s environment — that would hide GPUs from pynvml.
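One way to tell "waiting for GPUs to free up" apart from "can never fit" is to compare the request against total memory rather than free memory. A small helper (can_ever_fit is invented for this sketch) makes the check explicit against a /gpus snapshot:

```python
def can_ever_fit(gpus: list[dict], required_mb: int, num_gpus: int) -> bool:
    # If fewer than num_gpus cards are even physically large enough,
    # the job will sit in the queue forever.
    big_enough = [g for g in gpus if g["total_mb"] >= required_mb]
    return len(big_enough) >= num_gpus

# Snapshot in the shape returned by the /gpus endpoint
snapshot = [{"index": 0, "total_mb": 24_576, "free_mb": 2_000, "used_mb": 22_576}]
assert can_ever_fit(snapshot, 16_000, 1)       # will run once memory frees
assert not can_ever_fit(snapshot, 32_000, 1)   # misconfigured request
```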

Preempted job won’t restart

If the training script doesn’t handle SIGTERM gracefully, the process might leave GPU memory allocated until the driver cleans it up. Add a signal handler in your training scripts:

import signal
import sys

def handle_sigterm(signum, frame):
    print("Received SIGTERM, saving checkpoint and exiting...")
    # Save your checkpoint here
    sys.exit(0)

signal.signal(signal.SIGTERM, handle_sigterm)

File descriptor leak from subprocess logs

The open() call for log files in launch_job creates file handles that stick around. In production, wrap the launch in a context manager or store the file objects and close them in check_running_jobs when the process finishes.
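A sketch of one fix: keep the log handle next to the Popen object and close it when the process is reaped. TrackedJob and launch_tracked are names invented for this sketch, and sys.executable stands in for a bare "python" to sidestep PATH issues:

```python
import os
import subprocess
import sys

class TrackedJob:
    def __init__(self, proc: subprocess.Popen, log_file):
        self.proc = proc
        self.log_file = log_file

def launch_tracked(script: str, job_id: str, gpu_ids: list[int]) -> TrackedJob:
    env = os.environ.copy()
    env["CUDA_VISIBLE_DEVICES"] = ",".join(str(g) for g in gpu_ids)
    log_file = open(f"/tmp/train_{job_id}.log", "w")
    proc = subprocess.Popen(
        [sys.executable, script], env=env,
        stdout=log_file, stderr=subprocess.STDOUT,
    )
    return TrackedJob(proc, log_file)

# In check_running_jobs, once proc.poll() is not None:
#     tracked.log_file.close()
```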