The Short Version: Request Spot, Checkpoint Everything#
Spot instances give you the same GPU hardware at 60-90% off on-demand prices. The catch is they can be reclaimed with as little as 2 minutes notice. The entire strategy comes down to one principle: save checkpoints frequently enough that losing a machine is a minor inconvenience, not a disaster.
Here’s what a spot request for a GPU instance looks like with boto3:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import boto3

ec2 = boto3.client("ec2", region_name="us-east-1")

# Launch spec for an 8x A100 40GB box running a Deep Learning AMI.
launch_spec = {
    "ImageId": "ami-0abcdef1234567890",  # Deep Learning AMI
    "InstanceType": "p4d.24xlarge",  # 8x A100 40GB
    "KeyName": "my-training-key",
    "SecurityGroupIds": ["sg-0123456789abcdef0"],
    "BlockDeviceMappings": [
        {
            "DeviceName": "/dev/sda1",
            "Ebs": {"VolumeSize": 500, "VolumeType": "gp3"},
        }
    ],
}

response = ec2.request_spot_instances(
    SpotPrice="3.50",  # max price -- you pay the market rate, not this
    InstanceCount=1,
    LaunchSpecification=launch_spec,
)

spot_request_id = response["SpotInstanceRequests"][0]["SpotInstanceRequestId"]
print(f"Spot request submitted: {spot_request_id}")
|
Set SpotPrice above the typical market rate. You won’t pay that amount – AWS charges the current spot price, which is usually 60-70% below on-demand. Setting a high ceiling just prevents your request from being rejected when prices briefly spike.
Automatic Checkpoint Saving and Resuming with PyTorch#
The most important piece of your spot instance setup is a checkpoint system that saves state regularly and resumes transparently when a new instance spins up. Here’s a production-ready pattern:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
| import os
import signal
import torch
from pathlib import Path
class SpotCheckpointManager:
    """Handles checkpoint save/resume for spot instance training.

    Saves the full training state (model/optimizer/scheduler plus step,
    epoch, loss) and keeps only the most recent checkpoints. Installs a
    SIGTERM handler that flips ``interrupted`` so the training loop can
    save one last checkpoint and exit cleanly before the instance dies.
    """

    def __init__(self, checkpoint_dir: str, save_every_n_steps: int = 500):
        self.checkpoint_dir = Path(checkpoint_dir)
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        self.save_every_n_steps = save_every_n_steps
        self._interrupted = False
        # Catch SIGTERM -- AWS sends this 2 min before termination
        signal.signal(signal.SIGTERM, self._handle_signal)

    def _handle_signal(self, signum, frame):
        # Only flip a flag here: the training loop performs the actual
        # save, since torch.save is not safe inside a signal handler.
        print(f"Received signal {signum} -- saving emergency checkpoint")
        self._interrupted = True

    @property
    def interrupted(self) -> bool:
        return self._interrupted

    @staticmethod
    def _step_of(path: Path) -> int:
        """Extract the numeric step from a 'checkpoint_step_<N>.pt' name."""
        return int(path.stem.rsplit("_", 1)[-1])

    def _sorted_checkpoints(self) -> list:
        # Sort numerically by step. A plain lexicographic sort would put
        # step 1000 before step 200, rotating out the wrong files and
        # resuming from a stale checkpoint.
        return sorted(
            self.checkpoint_dir.glob("checkpoint_step_*.pt"), key=self._step_of
        )

    def save(self, model, optimizer, scheduler, step: int, epoch: int, loss: float):
        """Atomically write a checkpoint for `step` and rotate old ones."""
        path = self.checkpoint_dir / f"checkpoint_step_{step}.pt"
        # Write to a temp file, then rename: an interruption mid-write
        # must never leave a truncated checkpoint behind.
        tmp_path = path.with_name(path.name + ".tmp")
        torch.save(
            {
                "step": step,
                "epoch": epoch,
                "loss": loss,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "scheduler_state_dict": scheduler.state_dict(),
            },
            tmp_path,
        )
        os.replace(tmp_path, path)  # atomic on POSIX
        # Keep only the 3 most recent checkpoints to manage storage
        for old_ckpt in self._sorted_checkpoints()[:-3]:
            old_ckpt.unlink()
        print(f"Checkpoint saved: {path}")

    def load_latest(self, model, optimizer, scheduler):
        """Restore the newest checkpoint in-place; return (step, epoch).

        Returns (0, 0) when no checkpoint exists yet.
        """
        checkpoints = self._sorted_checkpoints()
        if not checkpoints:
            print("No checkpoint found -- starting from scratch")
            return 0, 0
        latest = checkpoints[-1]
        # Fall back to CPU so resuming also works on CPU-only machines.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # weights_only=False is required for optimizer/scheduler state;
        # only load checkpoints you wrote yourself (it can execute code).
        ckpt = torch.load(latest, map_location=device, weights_only=False)
        model.load_state_dict(ckpt["model_state_dict"])
        optimizer.load_state_dict(ckpt["optimizer_state_dict"])
        scheduler.load_state_dict(ckpt["scheduler_state_dict"])
        print(f"Resumed from {latest} (step {ckpt['step']}, loss {ckpt['loss']:.4f})")
        return ckpt["step"], ckpt["epoch"]
|
And the training loop that uses it:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
# Checkpoint dir should be on shared storage (EFS, S3, GCS)
ckpt_mgr = SpotCheckpointManager(
    checkpoint_dir="/mnt/efs/checkpoints/run-001",
    save_every_n_steps=200,
)

# Resume from last checkpoint if one exists
start_step, start_epoch = ckpt_mgr.load_latest(model, optimizer, scheduler)
global_step = start_step
# Loss of the most recent completed step, used for emergency saves.
last_loss = 0.0

for epoch in range(start_epoch, num_epochs):
    for batch in dataloader:
        # Check for an interrupt before spending time on a step. Use
        # last_loss here: on the very first iteration no `loss` variable
        # exists yet, so referencing it would raise NameError mid-shutdown.
        if ckpt_mgr.interrupted:
            ckpt_mgr.save(model, optimizer, scheduler, global_step, epoch, last_loss)
            print("Graceful shutdown after interrupt")
            raise SystemExit(0)

        outputs = model(**batch)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
        global_step += 1
        last_loss = loss.item()

        # Periodic checkpoint on the regular cadence.
        if global_step % ckpt_mgr.save_every_n_steps == 0:
            ckpt_mgr.save(model, optimizer, scheduler, global_step, epoch, last_loss)
|
Store checkpoints on persistent shared storage – EFS on AWS, Filestore on GCP, or sync to S3/GCS after each save. Local NVMe is fast for writing but you lose everything when the instance terminates.
Handling Spot Interruptions on AWS#
AWS gives you a 2-minute warning via the instance metadata endpoint. You can poll it or use the SIGTERM signal handler from the checkpoint manager above. Here’s a lightweight polling approach you can run as a background thread:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
| import requests
import threading
import time
def poll_spot_interruption(callback, interval: int = 5):
    """Polls the EC2 metadata endpoint for spot termination notices."""
    url = "http://169.254.169.254/latest/meta-data/spot/instance-action"
    token_url = "http://169.254.169.254/latest/api/token"

    def _probe():
        """One IMDSv2 round-trip; returns the response, or None on error."""
        try:
            # IMDSv2 requires a token
            token = requests.put(
                token_url,
                headers={"X-aws-ec2-metadata-token-ttl-seconds": "300"},
                timeout=2,
            ).text
            return requests.get(
                url, headers={"X-aws-ec2-metadata-token": token}, timeout=2
            )
        except requests.exceptions.RequestException:
            return None  # No interruption notice yet

    def _poll():
        # The endpoint returns 200 only once a termination is scheduled;
        # until then we just sleep and retry.
        while True:
            resp = _probe()
            if resp is not None and resp.status_code == 200:
                print(f"Spot interruption notice: {resp.json()}")
                callback()
                return
            time.sleep(interval)

    t = threading.Thread(target=_poll, daemon=True)
    t.start()
    return t
|
On GCP, preemptible VMs get a 30-second warning. You need to be more aggressive with checkpoint frequency there – save every 100 steps instead of 500.
SkyPilot: Multi-Cloud Spot Orchestration#
Manually managing spot requests across clouds gets old fast. SkyPilot abstracts away the cloud provider and automatically finds the cheapest spot GPUs across AWS, GCP, and Azure. It also handles preemption recovery out of the box.
1
2
| pip install "skypilot-nightly[aws,gcp]"
sky check
|
Define your training job in a YAML file:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
| # train.yaml
resources:
accelerators: A100:8
use_spot: true
spot_recovery: failover # auto-restart on another region/cloud
file_mounts:
/data:
source: s3://my-training-data
/checkpoints:
source: s3://my-checkpoints
mode: MOUNT
setup: |
pip install torch transformers accelerate
run: |
python train.py \
--checkpoint-dir /checkpoints \
--data-dir /data \
--save-every 200
|
Launch it:
1
| sky spot launch train.yaml --name gpu-training-run
|
SkyPilot will search across all configured clouds and regions for the cheapest available spot A100x8 instance. When that instance gets preempted, SkyPilot automatically finds another one and restarts your job. Your training script just needs to handle checkpoint loading, which we already covered.
Check status and logs:
1
2
3
| sky spot queue # list running spot jobs
sky spot logs gpu-training-run # stream logs
sky spot cancel gpu-training-run # stop the job
|
Spot Instance Pools and Fallback Strategies#
Don’t pin yourself to one instance type. Spread across multiple GPU types to increase your chances of getting capacity. On AWS, you can use a spot fleet:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import boto3

ec2 = boto3.client("ec2", region_name="us-east-1")

# All fallback types boot the same image; broaden the pool to keep
# capacity available when the preferred type runs dry.
AMI_ID = "ami-0abcdef1234567890"
FALLBACK_TYPES = [
    "p4d.24xlarge",  # 8x A100
    "p3.16xlarge",  # 8x V100 -- fallback
    "g5.48xlarge",  # 8x A10G -- cheaper fallback
]

response = ec2.request_spot_fleet(
    SpotFleetRequestConfig={
        "IamFleetRole": "arn:aws:iam::123456789012:role/spot-fleet-role",
        "TargetCapacity": 1,
        "AllocationStrategy": "capacityOptimized",  # prioritize availability
        "LaunchSpecifications": [
            {
                "InstanceType": instance_type,
                "ImageId": AMI_ID,
                "KeyName": "my-key",
            }
            for instance_type in FALLBACK_TYPES
        ],
    }
)
print(f"Fleet request: {response['SpotFleetRequestId']}")
|
Use capacityOptimized as the allocation strategy. It picks the pool with the most available capacity, which minimizes interruptions. The lowestPrice strategy sounds appealing but tends to land you in pools that run out of capacity quickly.
A solid fallback ladder looks like this:
- Spot A100s in your preferred region
- Spot A100s in any region
- Spot V100s or A10Gs (adjust batch size in your training config)
- On-demand A100s as a last resort (set a maximum spend limit)
Cost Comparison#
Real numbers from a recent training run (fine-tuning a 7B model, 20 hours total):
| Configuration | Instance | Hourly Cost | Total | Savings |
|---|---|---|---|---|
| On-demand | p4d.24xlarge | $32.77 | $655.40 | – |
| Spot (single region) | p4d.24xlarge | $9.83 | $196.60 | 70% |
| Spot (multi-region, SkyPilot) | p4d.24xlarge | $7.20 | $144.00 | 78% |
| Spot fallback mix | p4d/p3/g5 | $5.50 avg | $132.00 | 80% |
These numbers fluctuate. Spot prices change constantly and you’ll occasionally lose 15-30 minutes to preemption recovery. Budget for about 10-15% extra wall-clock time. Even accounting for that, spot instances are the single biggest cost lever for GPU training.
Common Errors and Fixes#
MaxSpotInstanceCountExceeded: Max spot instance count exceeded
AWS accounts have default spot vCPU limits. For GPU instances, this is often 0 for new accounts. Go to Service Quotas in the AWS console and request an increase for “All P Spot Instance Requests” and “All G and VT Spot Instance Requests.” It takes 1-3 business days.
Spot request stays in pending-evaluation forever
The instance type isn’t available as spot in your chosen region. Check current spot pricing and availability:
1
2
3
4
5
6
7
| aws ec2 describe-spot-price-history \
--instance-types p4d.24xlarge \
--product-descriptions "Linux/UNIX" \
--start-time "$(date -u +%Y-%m-%dT%H:%M:%S)" \
--region us-east-1 \
--query 'SpotPriceHistory[*].[AvailabilityZone,SpotPrice]' \
--output table
|
If there’s no pricing data, that instance type has zero spot capacity in that region. Try a different region or instance type.
Checkpoint file is corrupted after interruption
This happens when the instance terminates mid-write. Always write checkpoints atomically – save to a temp file first, then rename:
1
2
3
4
5
6
7
8
9
10
import os
import tempfile

def save_checkpoint_atomic(state_dict, path):
    """Write checkpoint atomically to prevent corruption.

    Saves to a temp file in the same directory (so the rename never
    crosses filesystems), fsyncs it, then renames over `path`. A spot
    interruption mid-write leaves either the old checkpoint or the new
    one on disk -- never a truncated file.
    """
    dir_name = os.path.dirname(path)
    tmp = tempfile.NamedTemporaryFile(dir=dir_name, delete=False)
    try:
        with tmp:
            # Write through the open handle so flush/fsync act on the
            # same descriptor that received the data (torch.save with
            # tmp.name would write via a second, separate fd).
            torch.save(state_dict, tmp)
            tmp.flush()
            os.fsync(tmp.fileno())
        # os.replace is atomic on POSIX and, unlike os.rename, also
        # overwrites an existing target on Windows.
        os.replace(tmp.name, path)
    except BaseException:
        # Don't leave a half-written temp file behind on failure.
        try:
            os.unlink(tmp.name)
        except OSError:
            pass
        raise
|
InsufficientInstanceCapacity on spot fleet request
All pools in your fleet are out of capacity. Add more instance types and regions to your fleet config. The broader your pool, the better your odds. Also consider off-peak hours – GPU spot availability is typically highest between 10 PM and 8 AM UTC.
Training diverges after resuming from checkpoint
Make sure you’re restoring the learning rate scheduler state, not just model and optimizer. The scheduler state includes the current step count, and without it your learning rate resets to the initial value. Also verify your dataloader skips batches that were already processed – use the global_step from the checkpoint to fast-forward.
SkyPilot can’t find any available spot instances
Expand your search space. Add more clouds with sky check, or relax your GPU requirement:
1
2
3
4
5
| # Check what's available right now
sky show-gpus --all
# Allow SkyPilot to pick any A100-class GPU
sky spot launch train.yaml --gpus A100 --use-spot
|