Why Multi-GPU Training
Single-GPU training hits a wall fast. Fine-tuning a 7B-parameter model can take 24+ hours on one A100; with 4 GPUs that drops to roughly 6 hours, and with 8 GPUs to roughly 3. As long as communication overhead stays small, scaling is close to linear: more GPUs means proportionally faster training.
PyTorch's DistributedDataParallel (DDP) is the standard approach. It replicates the model on every GPU, feeds each replica a different shard of the data, runs the forward and backward passes in parallel, and averages gradients across processes during the backward pass. Your model code barely changes.
Single-GPU Baseline
Here’s a standard training loop we’ll convert to multi-GPU.
```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# Simple model for demonstration
model = nn.Sequential(
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Linear(256, 10),
)

# Dummy dataset (replace with your real data)
X = torch.randn(10000, 784)
y = torch.randint(0, 10, (10000,))
dataset = TensorDataset(X, y)
loader = DataLoader(dataset, batch_size=64, shuffle=True)

model = model.to("cuda")
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

for epoch in range(10):
    total_loss = 0
    for batch_x, batch_y in loader:
        batch_x, batch_y = batch_x.to("cuda"), batch_y.to("cuda")
        optimizer.zero_grad()
        output = model(batch_x)
        loss = criterion(output, batch_y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch}: loss={total_loss/len(loader):.4f}")
```
Convert to Multi-GPU with DDP
The changes are minimal. You add initialization, wrap the model, and use a distributed sampler.
```python
import os
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, TensorDataset, DistributedSampler


def setup(rank, world_size):
    """Initialize the distributed process group."""
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)


def cleanup():
    """Destroy the process group."""
    dist.destroy_process_group()


def train(rank, world_size):
    """Training function that runs on each GPU."""
    setup(rank, world_size)

    # Model: same as before, but wrapped in DDP
    model = nn.Sequential(
        nn.Linear(784, 256),
        nn.ReLU(),
        nn.Linear(256, 10),
    ).to(rank)
    model = DDP(model, device_ids=[rank])

    # Dataset with DistributedSampler
    X = torch.randn(10000, 784)
    y = torch.randint(0, 10, (10000,))
    dataset = TensorDataset(X, y)
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    loader = DataLoader(dataset, batch_size=64, sampler=sampler)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(10):
        sampler.set_epoch(epoch)  # Shuffle differently each epoch
        total_loss = 0
        for batch_x, batch_y in loader:
            batch_x, batch_y = batch_x.to(rank), batch_y.to(rank)
            optimizer.zero_grad()
            output = model(batch_x)
            loss = criterion(output, batch_y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        if rank == 0:  # Only print from the main process
            print(f"Epoch {epoch}: loss={total_loss/len(loader):.4f}")

    cleanup()
```
Launch the Training
DDP uses one process per GPU. Launch with torch.multiprocessing or torchrun.
```python
import torch
import torch.multiprocessing as mp


def main():
    world_size = torch.cuda.device_count()
    print(f"Training on {world_size} GPUs")
    # Spawns one process per GPU; each process calls train(rank, world_size)
    mp.spawn(train, args=(world_size,), nprocs=world_size, join=True)


if __name__ == "__main__":
    main()
```
Or use torchrun from the command line (preferred for production):
```bash
# Automatically handles process spawning and environment variables
torchrun --nproc_per_node=4 train.py
```
torchrun is generally preferable to mp.spawn: it sets the rank and rendezvous environment variables for you and supports fault-tolerant, elastic training, so a failed worker group can be restarted automatically when restarts are enabled.
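One caveat: the train() function above is written for mp.spawn, which passes rank explicitly. Under torchrun, each process instead reads its rank from environment variables that the launcher sets. Below is a minimal sketch of an env-based setup, assuming a single node and keeping the rest of the training loop unchanged.

```python
import os
import torch
import torch.distributed as dist


def setup_from_env():
    """Initialize the process group from the env vars torchrun sets."""
    # init_process_group defaults to the env:// method, which reads
    # MASTER_ADDR, MASTER_PORT, RANK, and WORLD_SIZE from the environment.
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])  # GPU index on this node
    torch.cuda.set_device(local_rank)
    return local_rank, dist.get_world_size()


if __name__ == "__main__":
    rank, world_size = setup_from_env()
    # Here you would run the same loop body as train() above
    # (without its own setup()/cleanup() calls), then tear down:
    dist.destroy_process_group()
```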
Key Differences from Single-GPU
| Concept | Single GPU | Multi-GPU (DDP) |
|---|---|---|
| Model | `model.to("cuda")` | `DDP(model.to(rank), device_ids=[rank])` |
| DataLoader | `shuffle=True` | `sampler=DistributedSampler(...)` |
| Device | `"cuda"` | `rank` (GPU index) |
| Logging | Always print | `if rank == 0:` |
| Saving | `torch.save(model.state_dict())` | `torch.save(model.module.state_dict())` |
Save and Load Checkpoints
Save only from rank 0: every replica holds identical weights after each step, so one copy is enough, and concurrent writes to the same file would corrupt it.
```python
def save_checkpoint(model, optimizer, epoch, path):
    """Save a training checkpoint (call only from rank 0)."""
    torch.save({
        "epoch": epoch,
        "model_state_dict": model.module.state_dict(),  # .module unwraps DDP
        "optimizer_state_dict": optimizer.state_dict(),
    }, path)
    print(f"Checkpoint saved: {path}")


def load_checkpoint(model, optimizer, path, rank):
    """Load a checkpoint onto the correct device."""
    map_location = {"cuda:0": f"cuda:{rank}"}
    checkpoint = torch.load(path, map_location=map_location)
    model.module.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    return checkpoint["epoch"]


# In your training loop:
if rank == 0 and epoch % 5 == 0:
    save_checkpoint(model, optimizer, epoch, f"checkpoint_epoch_{epoch}.pt")
```
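If the save and a subsequent load happen in the same script (for example, restoring weights on every rank before evaluation), put a barrier between them so no rank reads a half-written file. A sketch of that pattern, using the helpers above and a hypothetical checkpoint_path on storage visible to all ranks:

```python
# Hypothetical resume step inside the training loop.
checkpoint_path = "checkpoint_latest.pt"

if rank == 0:
    save_checkpoint(model, optimizer, epoch, checkpoint_path)
dist.barrier()  # wait until rank 0 has finished writing the file
start_epoch = load_checkpoint(model, optimizer, checkpoint_path, rank) + 1
```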
Scaling the Learning Rate
When you train on N GPUs, each process still sees batch_size samples per step, so the effective (global) batch size is N * batch_size; with batch_size=64 on 4 GPUs that is 256 samples per optimizer step. A common heuristic is to scale the learning rate linearly to match.
```python
world_size = torch.cuda.device_count()
base_lr = 1e-3
scaled_lr = base_lr * world_size  # Linear scaling rule
optimizer = torch.optim.Adam(model.parameters(), lr=scaled_lr)
```
For large scaling factors (8+ GPUs), use learning rate warmup to avoid training instability:
```python
from torch.optim.lr_scheduler import LinearLR, CosineAnnealingLR, SequentialLR

warmup = LinearLR(optimizer, start_factor=0.1, total_iters=100)
cosine = CosineAnnealingLR(optimizer, T_max=1000)
scheduler = SequentialLR(optimizer, [warmup, cosine], milestones=[100])
```
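Note that total_iters and the milestone above are counted in scheduler steps, so where scheduler.step() is called matters. A sketch assuming per-batch stepping (reusing loader, model, criterion, optimizer, and rank from the DDP loop), so the 100-iteration warmup spans the first 100 batches:

```python
for batch_x, batch_y in loader:
    batch_x, batch_y = batch_x.to(rank), batch_y.to(rank)
    optimizer.zero_grad()
    loss = criterion(model(batch_x), batch_y)
    loss.backward()
    optimizer.step()
    scheduler.step()  # advance the warmup/cosine schedule once per batch
```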
Common Issues
NCCL timeout errors. Usually one rank never reached a collective call: a process crashed, the ranks took different code paths, or one GPU is far slower than the rest. Check with nvidia-smi that all GPUs are healthy and on the same driver version, and make sure every rank executes the same sequence of forward/backward passes.
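While debugging, you can also raise the collective timeout so a slow rank surfaces as a clear error rather than an early abort; setting the NCCL_DEBUG=INFO environment variable additionally makes NCCL log its communication setup. A sketch of the timeout argument inside setup(), with the 60-minute value chosen purely as an example:

```python
import datetime

import torch.distributed as dist

# Same call as in setup(), with an explicit timeout (the default is 30 minutes).
dist.init_process_group(
    "nccl",
    rank=rank,
    world_size=world_size,
    timeout=datetime.timedelta(minutes=60),
)
```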
Hanging at init_process_group. All processes must call init_process_group before any can proceed. Make sure your code doesn’t have conditional paths that skip initialization on some ranks.
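The same rule applies to every collective call, not just initialization: a collective guarded by a rank check will hang the ranks that do reach it. A small sketch of the safe pattern, where only the side effect is rank-guarded:

```python
# Every rank calls the collective; only the side effect is rank-guarded.
dist.barrier()  # all ranks must reach this together
if rank == 0:
    print("all ranks are past initialization")

# Hangs: if only rank 0 called dist.barrier(), it would wait forever
# for the other ranks, which never reach the call.
```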
Different results vs single GPU. Expected: the sampling order and the effective batch size both differ. Call sampler.set_epoch(epoch) at the start of every epoch so shuffling actually changes between epochs; final quality should match the single-GPU run.
OOM on one GPU but not others. GPU 0 often carries extra memory because every process accidentally touches it, for example by calling torch.load without a map_location or by allocating tensors before torch.cuda.set_device(rank). Make sure --nproc_per_node matches your actual GPU count, and restrict which devices are visible with CUDA_VISIBLE_DEVICES if needed.
```bash
# Use only GPUs 0,1,2,3
CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 train.py
```
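To confirm which process is actually filling a device, it can help to log per-rank memory from inside the training loop; a small sketch, assuming rank is the local GPU index as in train() above:

```python
# Per-rank memory report; torch.cuda.memory_allocated returns bytes.
allocated_mib = torch.cuda.memory_allocated(rank) / 1024**2
print(f"[rank {rank}] allocated {allocated_mib:.0f} MiB on cuda:{rank}")
```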