Why Multi-GPU Training

Single-GPU training hits a wall fast. A 7B parameter model can take 24+ hours to fine-tune on one A100. With 4 GPUs, that drops to roughly 6 hours; with 8 GPUs, roughly 3 hours. The scaling is close to linear: more GPUs means proportionally faster training, minus a modest communication overhead for synchronizing gradients.

PyTorch’s DistributedDataParallel (DDP) is the standard approach. It replicates the model on every GPU, gives each replica a different shard of each batch, runs the forward/backward passes in parallel, and averages gradients across GPUs after every backward pass. Your model code barely changes.
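Conceptually, the gradient synchronization DDP performs boils down to an all-reduce over every parameter's gradient after the backward pass. The sketch below is a rough manual equivalent, shown only for illustration; real DDP does this in buckets, overlapped with the backward pass, so you never write it yourself.

import torch.distributed as dist

def average_gradients(model, world_size):
    """Illustration only: average gradients across all ranks, as DDP does internally."""
    for param in model.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)  # sum this gradient over all GPUs
            param.grad /= world_size                           # turn the sum into an average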

Single-GPU Baseline

Here’s a standard training loop we’ll convert to multi-GPU.

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# Simple model for demonstration
model = nn.Sequential(
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Linear(256, 10),
)

# Dummy dataset (replace with your real data)
X = torch.randn(10000, 784)
y = torch.randint(0, 10, (10000,))
dataset = TensorDataset(X, y)
loader = DataLoader(dataset, batch_size=64, shuffle=True)

model = model.to("cuda")
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

for epoch in range(10):
    total_loss = 0
    for batch_x, batch_y in loader:
        batch_x, batch_y = batch_x.to("cuda"), batch_y.to("cuda")
        optimizer.zero_grad()
        output = model(batch_x)
        loss = criterion(output, batch_y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch}: loss={total_loss/len(loader):.4f}")

Convert to Multi-GPU with DDP

The changes are minimal: initialize a process group, wrap the model in DDP, and swap shuffle=True for a DistributedSampler.

import os
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, TensorDataset, DistributedSampler


def setup(rank, world_size):
    """Initialize the distributed process group."""
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)


def cleanup():
    """Destroy the process group."""
    dist.destroy_process_group()


def train(rank, world_size):
    """Training function that runs on each GPU."""
    setup(rank, world_size)

    # Model — same as before, but wrapped in DDP
    model = nn.Sequential(
        nn.Linear(784, 256),
        nn.ReLU(),
        nn.Linear(256, 10),
    ).to(rank)
    model = DDP(model, device_ids=[rank])

    # Dataset with DistributedSampler
    X = torch.randn(10000, 784)
    y = torch.randint(0, 10, (10000,))
    dataset = TensorDataset(X, y)
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    loader = DataLoader(dataset, batch_size=64, sampler=sampler)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(10):
        sampler.set_epoch(epoch)  # Shuffle differently each epoch
        total_loss = 0
        for batch_x, batch_y in loader:
            batch_x, batch_y = batch_x.to(rank), batch_y.to(rank)
            optimizer.zero_grad()
            output = model(batch_x)
            loss = criterion(output, batch_y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        if rank == 0:  # Only print from the main process
            print(f"Epoch {epoch}: loss={total_loss/len(loader):.4f}")

    cleanup()

Launch the Training

DDP uses one process per GPU. Launch with torch.multiprocessing or torchrun.

import torch
import torch.multiprocessing as mp

def main():
    world_size = torch.cuda.device_count()
    print(f"Training on {world_size} GPUs")
    mp.spawn(train, args=(world_size,), nprocs=world_size, join=True)

if __name__ == "__main__":
    main()

Or use torchrun from the command line (preferred for production):

# Automatically handles process spawning and environment variables
torchrun --nproc_per_node=4 train.py

torchrun is preferred over mp.spawn because it handles fault tolerance and elastic training: if a worker process crashes, it can restart the worker group automatically. Note that torchrun spawns one process per GPU itself and passes rank and world size through environment variables rather than function arguments.
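A minimal sketch of a torchrun-style entry point, assuming the same model and training loop as above. torchrun sets MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE, and LOCAL_RANK in the environment, so the manual setup() is no longer needed:

import os
import torch
import torch.distributed as dist

def main():
    # torchrun already provides MASTER_ADDR/MASTER_PORT, RANK, and WORLD_SIZE
    dist.init_process_group("nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = dist.get_world_size()
    torch.cuda.set_device(local_rank)

    # ... build the model, wrap it in DDP(device_ids=[local_rank]),
    # and run the same training loop as in train() above ...

    dist.destroy_process_group()

if __name__ == "__main__":
    main()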

Key Differences from Single-GPU

| Concept    | Single GPU                     | Multi-GPU (DDP)                        |
|------------|--------------------------------|----------------------------------------|
| Model      | model.to("cuda")               | DDP(model.to(rank), device_ids=[rank]) |
| DataLoader | shuffle=True                   | sampler=DistributedSampler(...)        |
| Device     | "cuda"                         | rank (GPU index)                       |
| Logging    | Always print                   | if rank == 0:                          |
| Saving     | torch.save(model.state_dict()) | torch.save(model.module.state_dict())  |

Save and Load Checkpoints

Save only from rank 0. After gradient synchronization every rank holds identical weights, so one copy is enough, and having several processes write the same file concurrently risks corrupting it.

def save_checkpoint(model, optimizer, epoch, path):
    """Save a training checkpoint (call only from rank 0)."""
    torch.save({
        "epoch": epoch,
        "model_state_dict": model.module.state_dict(),  # .module unwraps DDP
        "optimizer_state_dict": optimizer.state_dict(),
    }, path)
    print(f"Checkpoint saved: {path}")


def load_checkpoint(model, optimizer, path, rank):
    """Load a checkpoint onto the correct device."""
    map_location = {"cuda:0": f"cuda:{rank}"}
    checkpoint = torch.load(path, map_location=map_location)
    model.module.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    return checkpoint["epoch"]


# In your training loop:
if rank == 0 and epoch % 5 == 0:
    save_checkpoint(model, optimizer, epoch, f"checkpoint_epoch_{epoch}.pt")
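When resuming, every rank needs the weights, not just rank 0. A common pattern, sketched here with an illustrative filename, is to barrier after rank 0 writes and then let every rank load the same file onto its own GPU:

# Rank 0 writes the file; everyone else waits at the barrier until it exists
if rank == 0:
    save_checkpoint(model, optimizer, epoch, "checkpoint_latest.pt")
dist.barrier()

# Every rank loads the checkpoint, mapped to its own device
start_epoch = load_checkpoint(model, optimizer, "checkpoint_latest.pt", rank)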

Scaling the Learning Rate

When you use N GPUs, each GPU processes batch_size samples per step, so the effective batch size is N * batch_size. A common heuristic is to scale the learning rate linearly with the GPU count; treat it as a starting point rather than a rule, especially with Adam.

world_size = torch.cuda.device_count()
base_lr = 1e-3
scaled_lr = base_lr * world_size  # Linear scaling rule

optimizer = torch.optim.Adam(model.parameters(), lr=scaled_lr)

For large scaling factors (8+ GPUs), use learning rate warmup to avoid training instability:

from torch.optim.lr_scheduler import LinearLR, CosineAnnealingLR, SequentialLR

# Warm up from 10% of the target LR over the first 100 steps, then decay with cosine
warmup = LinearLR(optimizer, start_factor=0.1, total_iters=100)
cosine = CosineAnnealingLR(optimizer, T_max=1000)
scheduler = SequentialLR(optimizer, schedulers=[warmup, cosine], milestones=[100])
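The warmup length and milestone above count scheduler.step() calls, so for a 100-step warmup the scheduler should be stepped once per batch, right after the optimizer. A minimal sketch:

for epoch in range(10):
    sampler.set_epoch(epoch)
    for batch_x, batch_y in loader:
        optimizer.zero_grad()
        # ... forward pass, loss, loss.backward() ...
        optimizer.step()
        scheduler.step()  # advance the warmup/cosine schedule by one training step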

Common Issues

NCCL timeout errors. Usually means one rank never reached a collective the others are waiting on: a straggler GPU, a crashed process, or ranks taking different code paths. Check that all GPUs are healthy and running the same driver version with nvidia-smi.
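For more visibility while debugging, NCCL can print its initialization and communication logs from every rank:

# Ask NCCL to log what it is doing on each rank
NCCL_DEBUG=INFO torchrun --nproc_per_node=4 train.py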

Hanging at init_process_group. All processes must call init_process_group before any can proceed. Make sure your code doesn’t have conditional paths that skip initialization on some ranks.

Different results vs single GPU. Expected — the random sampling order differs. Set sampler.set_epoch(epoch) every epoch for proper shuffling. Results should converge to the same quality.

OOM on one GPU but not others. With DDP, extra memory on GPU 0 usually means other processes accidentally created a CUDA context or loaded tensors there, for example by calling torch.load without map_location or allocating before torch.cuda.set_device(rank). Keep --nproc_per_node matched to your actual GPU count, and set CUDA_VISIBLE_DEVICES if needed.

# Use only GPUs 0,1,2,3
CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 train.py