Training a large model for 48 hours and losing everything to a preempted spot instance is a rite of passage nobody wants. Checkpoints fix that. But most tutorials stop at torch.save() and call it a day. A real checkpoint pipeline handles automatic intervals, disk rotation, crash recovery, and background uploads to object storage.
Here’s the minimal version. Save the model state, optimizer state, epoch, and loss so you can pick up exactly where you left off:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import torch
import torch.nn as nn
from torchvision.models import resnet18

model = resnet18(num_classes=10)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# Bundle everything needed to resume: weights, optimizer state, progress, metrics.
state = {
    "epoch": 5,
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "loss": 0.342,
    "lr_scheduler_state": None,  # add scheduler state if you use one
}
torch.save(state, "checkpoint_epoch_5.pt")

# Restore the run from the file we just wrote.
restored = torch.load("checkpoint_epoch_5.pt", weights_only=False)
model.load_state_dict(restored["model_state_dict"])
optimizer.load_state_dict(restored["optimizer_state_dict"])
start_epoch = restored["epoch"] + 1
print(f"Resumed from epoch {start_epoch}, loss={restored['loss']:.4f}")
|
That covers the basics. Now let’s build something you’d actually use in production.
Automatic Checkpointing with Rotation#
You don’t want to checkpoint every epoch and fill your disk. You also don’t want to keep only one checkpoint in case the latest one is corrupt. The standard approach: save every N steps and keep the last K checkpoints.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
| import os
import glob
import torch
import torch.nn as nn
from pathlib import Path
from torchvision.models import resnet18
class CheckpointManager:
    """Handles automatic checkpoint saving with rotation.

    Saves full training state under ``checkpoint_dir`` and keeps only the
    ``max_keep`` most recent checkpoint files so the disk doesn't fill up.
    """

    def __init__(self, checkpoint_dir: str, max_keep: int = 3):
        self.checkpoint_dir = Path(checkpoint_dir)
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        self.max_keep = max_keep

    def save(self, model, optimizer, epoch, step, loss, **extra):
        """Atomically write one checkpoint, then prune old ones.

        Any ``extra`` keyword args (e.g. scheduler state) are stored
        alongside the standard fields.
        """
        checkpoint = {
            "epoch": epoch,
            "global_step": step,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "loss": loss,
            **extra,
        }
        path = self.checkpoint_dir / f"ckpt_epoch{epoch}_step{step}.pt"
        # Write to a ".tmp" sibling and rename: a crash mid-save can never
        # leave a truncated file under the final name (rename is atomic on
        # POSIX), and the ".tmp" suffix keeps it out of the "ckpt_*.pt" glob
        # used by rotation and latest().
        tmp_path = path.with_name(path.name + ".tmp")
        torch.save(checkpoint, tmp_path)
        os.replace(tmp_path, path)
        print(f"Saved checkpoint: {path}")
        self._rotate()

    def _list_checkpoints(self):
        """Return checkpoints oldest-first; filename breaks mtime ties so
        ordering is deterministic even for saves within one clock tick."""
        return sorted(
            self.checkpoint_dir.glob("ckpt_*.pt"),
            key=lambda p: (p.stat().st_mtime, p.name),
        )

    def _rotate(self):
        """Delete oldest checkpoints until at most ``max_keep`` remain."""
        checkpoints = self._list_checkpoints()
        while len(checkpoints) > self.max_keep:
            oldest = checkpoints.pop(0)
            oldest.unlink()
            print(f"Deleted old checkpoint: {oldest}")

    def latest(self):
        """Return the Path of the newest checkpoint, or None if none exist."""
        checkpoints = self._list_checkpoints()
        return checkpoints[-1] if checkpoints else None
# Usage in a training loop
model = resnet18(num_classes=10).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()
ckpt_mgr = CheckpointManager("./checkpoints", max_keep=3)

# Synthetic dataset so the example runs end to end
inputs = torch.randn(1000, 3, 224, 224)
targets = torch.randint(0, 10, (1000,))
dataset = torch.utils.data.TensorDataset(inputs, targets)
loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)

save_every_steps = 100
global_step = 0
for epoch in range(20):
    for batch_x, batch_y in loader:
        batch_x = batch_x.cuda()
        batch_y = batch_y.cuda()
        optimizer.zero_grad()
        outputs = model(batch_x)
        loss = criterion(outputs, batch_y)
        loss.backward()
        optimizer.step()

        global_step += 1
        if global_step % save_every_steps == 0:
            ckpt_mgr.save(model, optimizer, epoch, global_step, loss.item())
|
The max_keep=3 parameter means you always have the three most recent checkpoints on disk. Old ones get deleted automatically. Adjust based on your model size and disk budget. A 7B model's weights alone are roughly 14 GB in fp16 — and a full checkpoint like the ones saved here also carries optimizer state, which for Adam adds about twice the weight size again — so three copies can easily eat well over 100 GB.
Resuming Training After a Crash#
The whole point of checkpoints is crash recovery. Your training script should check for existing checkpoints on startup and resume transparently.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
def resume_training(model, optimizer, checkpoint_dir: str, map_location: str = "cuda"):
    """Restore model/optimizer from the newest checkpoint in ``checkpoint_dir``.

    Returns ``(start_epoch, global_step)``, or ``(0, 0)`` when no checkpoint
    exists. ``map_location`` controls where tensors land on load -- pass
    ``"cpu"`` or a specific device like ``"cuda:0"`` when resuming on
    different hardware than the machine that saved.
    """
    ckpt_mgr = CheckpointManager(checkpoint_dir)
    latest_path = ckpt_mgr.latest()
    if latest_path is None:
        print("No checkpoint found. Starting from scratch.")
        return 0, 0
    checkpoint = torch.load(latest_path, weights_only=False, map_location=map_location)
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    # Checkpoints are written mid-epoch (every N steps), so resume *within*
    # the saved epoch rather than skipping to epoch + 1.
    start_epoch = checkpoint["epoch"]
    global_step = checkpoint["global_step"]
    print(f"Resumed from {latest_path} (epoch={start_epoch}, step={global_step})")
    return start_epoch, global_step
# Wire it into the training loop
net = resnet18(num_classes=10).cuda()
opt = torch.optim.Adam(net.parameters(), lr=1e-3)
start_epoch, global_step = resume_training(net, opt, "./checkpoints")
for epoch in range(start_epoch, 20):
    # training continues from where it left off
    pass
|
One subtle point: map_location="cuda" ensures tensors load onto the right device. If you saved on GPU 0 but resume on a different machine, you might need map_location="cuda:0" or even "cpu" first.
Async Checkpoint Saving#
Saving a 14 GB checkpoint to disk takes a few seconds. Saving it to NFS or S3 takes longer. You don’t want your training loop blocked while bytes hit the wire. Use a background thread.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
| import threading
import torch
import io
import boto3
from pathlib import Path
from torchvision.models import resnet18
class AsyncCheckpointSaver:
    """Save checkpoints in a background thread to avoid blocking training.

    Serializes the checkpoint to bytes synchronously (so tensor values are
    captured before the next optimizer step can mutate them), then hands
    the slow disk/S3 I/O off to a background thread.
    """

    def __init__(self, checkpoint_dir: str, s3_bucket: str = None, s3_prefix: str = ""):
        self.checkpoint_dir = Path(checkpoint_dir)
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        self.s3_bucket = s3_bucket  # None disables the S3 upload step
        self.s3_prefix = s3_prefix
        self._threads: list[threading.Thread] = []

    def save_async(self, checkpoint: dict, filename: str):
        """Snapshot ``checkpoint`` to bytes now; write/upload in the background."""
        # Serialize in the caller's thread: the state dicts reference live
        # (possibly GPU) tensors that the training loop may modify next step.
        buffer = io.BytesIO()
        torch.save(checkpoint, buffer)
        data = buffer.getvalue()
        t = threading.Thread(target=self._write, args=(data, filename))
        t.start()
        self._threads.append(t)

    def _write(self, data: bytes, filename: str):
        """Background worker: persist bytes locally, then mirror to S3."""
        local_path = self.checkpoint_dir / filename
        with open(local_path, "wb") as f:
            f.write(data)
        print(f"Saved locally: {local_path}")
        # Upload to S3 if configured
        if self.s3_bucket:
            s3 = boto3.client("s3")
            # The key must end with the actual filename, under the optional prefix.
            s3_key = f"{self.s3_prefix}/{filename}" if self.s3_prefix else filename
            s3.put_object(Bucket=self.s3_bucket, Key=s3_key, Body=data)
            print(f"Uploaded to s3://{self.s3_bucket}/{s3_key}")

    def wait(self):
        """Call at the end of training to ensure all saves complete."""
        for t in self._threads:
            t.join()
        self._threads.clear()
# Example usage
saver = AsyncCheckpointSaver(
    checkpoint_dir="./checkpoints",
    s3_bucket="my-training-checkpoints",
    s3_prefix="resnet18/run-001",
)
model = resnet18(num_classes=10).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# During training
state = {
    "epoch": 10,
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "loss": 0.15,
}
saver.save_async(state, "ckpt_epoch10.pt")

# At the end of training
saver.wait()
|
The trick is serializing to bytes in the main thread with torch.save(checkpoint, buffer). The state dicts reference GPU tensors, and you need to read them before the next optimizer step potentially modifies them. The background thread only handles I/O, which is the slow part.
Multi-GPU Checkpointing with DistributedDataParallel#
When training with DDP, only rank 0 should save checkpoints. Every rank has an identical copy of the model, so saving from all ranks is wasteful and can cause file corruption from concurrent writes.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
| import os
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torchvision.models import resnet18
def setup_distributed():
    """Initialize the NCCL process group and bind this process to its GPU.

    Returns the local rank (the torchrun-assigned GPU index on this node).
    """
    dist.init_process_group(backend="nccl")
    rank_on_node = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(rank_on_node)
    return rank_on_node
def save_distributed_checkpoint(model, optimizer, epoch, loss, path):
    """Save checkpoint only from rank 0."""
    is_rank_zero = dist.get_rank() == 0
    if is_rank_zero:
        # model.module unwraps the DDP container, so saved keys carry no
        # "module." prefix.
        state = {
            "epoch": epoch,
            "model_state_dict": model.module.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "loss": loss,
        }
        torch.save(state, path)
        print(f"[Rank 0] Saved checkpoint: {path}")
    # All ranks wait until rank 0 finishes saving
    dist.barrier()
def load_distributed_checkpoint(model, optimizer, path, device):
    """Load checkpoint on all ranks."""
    # Remap tensors that were saved from GPU 0 onto this rank's device.
    remap = {"cuda:0": f"cuda:{device}"}
    state = torch.load(path, weights_only=False, map_location=remap)
    model.module.load_state_dict(state["model_state_dict"])
    optimizer.load_state_dict(state["optimizer_state_dict"])
    return state["epoch"]
# Training script (launch with torchrun --nproc_per_node=4 train.py)
local_rank = setup_distributed()
model = resnet18(num_classes=10).to(local_rank)
model = DDP(model, device_ids=[local_rank])
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# torch.save does not create parent directories -- make sure the target
# directory exists before the first checkpoint (every rank may run this;
# exist_ok makes it safe).
os.makedirs("checkpoints", exist_ok=True)

for epoch in range(50):
    # ... training loop ...
    save_distributed_checkpoint(
        model, optimizer, epoch, 0.25, f"checkpoints/ckpt_epoch{epoch}.pt"
    )

dist.destroy_process_group()
|
Two things to watch for. First, always use model.module.state_dict() instead of model.state_dict() when the model is wrapped in DDP. The .module attribute gives you the underlying model without the DDP wrapper, which means the saved keys won’t have a module. prefix. This makes loading simpler and lets you use the checkpoint in single-GPU inference without any key remapping.
Second, the dist.barrier() call after saving is critical. Without it, rank 1 might try to load a checkpoint that rank 0 hasn’t finished writing yet.
Common Errors and Fixes#
RuntimeError: Error(s) in loading state_dict: Missing key(s) – You saved with DDP wrapping (model.state_dict()) but loaded without it. The keys have a module. prefix. Either save with model.module.state_dict() or strip prefixes on load:
1
2
3
state_dict = checkpoint["model_state_dict"]
# str.removeprefix only strips a *leading* "module." -- unlike str.replace,
# which would also mangle any key that merely contains the substring
# (e.g. "encoder.module.weight").
cleaned = {k.removeprefix("module."): v for k, v in state_dict.items()}
model.load_state_dict(cleaned)
|
CUDA out of memory when loading checkpoint – The checkpoint was saved on GPU and torch.load() tries to put tensors back on GPU. Load to CPU first:
1
2
3
ckpt = torch.load("ckpt.pt", weights_only=False, map_location="cpu")
model.load_state_dict(ckpt["model_state_dict"])
model = model.cuda()  # move after loading
|
Checkpoint file is corrupt / incomplete – Training was killed mid-save. Write to a temporary file and rename atomically:
1
2
3
4
5
import os

# No tempfile module needed: a sibling ".tmp" name plus an atomic rename
# on the same filesystem is enough.
tmp_path = path + ".tmp"
torch.save(checkpoint, tmp_path)
os.replace(tmp_path, path)  # atomic on POSIX systems
|
Optimizer state doesn't match after model changes – If you modify the model architecture between runs, the optimizer state dict won’t match. Skip loading the optimizer state and accept that learning rate warmup restarts:
1
2
3
model.load_state_dict(checkpoint["model_state_dict"])
# The architecture changed, so the saved optimizer state is stale -- rebuild it.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
|
S3 upload fails silently – boto3 doesn’t always raise on network errors with put_object. For large checkpoints, use multipart upload or verify with a HEAD request after upload:
1
2
3
s3.put_object(Bucket=bucket, Key=key, Body=data)
# Confirm the object actually landed with the expected size.
head = s3.head_object(Bucket=bucket, Key=key)
assert head["ContentLength"] == len(data), "Upload size mismatch"
|