Multi-GPU on one machine gets you far. Multi-node across a cluster gets you to production scale. Lightning Fabric makes the jump from single-node to multi-node surprisingly painless – you change maybe ten lines and add a launch command. NCCL handles the GPU-to-GPU communication behind the scenes, and it’s the fastest backend available for NVIDIA hardware.
Here’s the minimal setup to get a Fabric script running across two nodes:
```bash
pip install lightning torch torchvision
```
```python
from lightning.fabric import Fabric

fabric = Fabric(
    accelerator="cuda",
    devices=2,       # GPUs per node
    num_nodes=2,     # total nodes in the cluster
    strategy="ddp",  # uses the NCCL backend by default on CUDA
)
fabric.launch()

print(f"Global rank: {fabric.global_rank}, Local rank: {fabric.local_rank}, World size: {fabric.world_size}")
```
That’s the skeleton. Fabric’s DDP strategy defaults to the NCCL backend when you’re on CUDA devices. You don’t need to manually call torch.distributed.init_process_group – Fabric handles that internally. The num_nodes parameter tells Fabric how many machines participate, and devices is the GPU count per machine.
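To see what that internal call amounts to, here is a minimal sketch of the process-group setup Fabric performs for you. It uses the CPU-friendly `gloo` backend with a single process so it runs anywhere; on CUDA devices Fabric selects `nccl` instead, and the rendezvous address/port come from the launcher rather than being hardcoded:

```python
import os
import torch.distributed as dist

# Roughly what fabric.launch() does under the hood (single-process sketch):
# set the rendezvous point, then initialize the process group.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

backend = dist.get_backend()        # "gloo" here; "nccl" in a CUDA run
world_size = dist.get_world_size()
rank = dist.get_rank()
print(backend, world_size, rank)

dist.destroy_process_group()
```

After this, collectives like `all_reduce` and `barrier` work across every rank in the group, which is exactly what Fabric's helpers build on.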
## Full Multi-Node Training Script
Here’s a complete, runnable training script that trains a ResNet-18 on CIFAR-10 across multiple nodes. Every variable is defined, every import is real.
```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, DistributedSampler
from torchvision import datasets, transforms, models

from lightning.fabric import Fabric


def get_dataloaders(fabric, batch_size=64):
    # Augmentation only for training; validation gets a plain transform
    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, padding=4),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
    ])
    val_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
    ])
    train_dataset = datasets.CIFAR10(
        root="./data", train=True, download=True, transform=train_transform
    )
    val_dataset = datasets.CIFAR10(
        root="./data", train=False, download=True, transform=val_transform
    )

    # DistributedSampler shards data across all ranks automatically
    train_sampler = DistributedSampler(
        train_dataset,
        num_replicas=fabric.world_size,
        rank=fabric.global_rank,
        shuffle=True,
    )
    val_sampler = DistributedSampler(
        val_dataset,
        num_replicas=fabric.world_size,
        rank=fabric.global_rank,
        shuffle=False,
    )
    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, sampler=train_sampler,
        num_workers=4, pin_memory=True,
    )
    val_loader = DataLoader(
        val_dataset, batch_size=batch_size, sampler=val_sampler,
        num_workers=4, pin_memory=True,
    )
    return train_loader, val_loader, train_sampler


def train_one_epoch(fabric, model, optimizer, train_loader, train_sampler, epoch):
    model.train()
    train_sampler.set_epoch(epoch)  # critical for proper shuffling each epoch
    running_loss = 0.0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = nn.functional.cross_entropy(outputs, targets)
        fabric.backward(loss)
        optimizer.step()

        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        if batch_idx % 50 == 0 and fabric.global_rank == 0:
            print(
                f"Epoch {epoch} | Batch {batch_idx}/{len(train_loader)} | "
                f"Loss: {running_loss / (batch_idx + 1):.4f} | "
                f"Acc: {100.0 * correct / total:.2f}%"
            )
    return running_loss / len(train_loader)


@torch.no_grad()
def validate(fabric, model, val_loader):
    model.eval()
    correct = 0
    total = 0
    for inputs, targets in val_loader:
        outputs = model(inputs)
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

    # Gather counts across all ranks so we report global accuracy
    correct_tensor = fabric.all_reduce(torch.tensor(correct, device=fabric.device), reduce_op="sum")
    total_tensor = fabric.all_reduce(torch.tensor(total, device=fabric.device), reduce_op="sum")
    accuracy = 100.0 * correct_tensor.item() / total_tensor.item()
    return accuracy


def main():
    fabric = Fabric(
        accelerator="cuda",
        devices=2,
        num_nodes=2,
        strategy="ddp",
        precision="16-mixed",
    )
    fabric.launch()

    # ResNet-18 adapted for CIFAR-10's 32x32 images
    model = models.resnet18(num_classes=10)
    model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    model.maxpool = nn.Identity()  # skip maxpool for small images

    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50)
    model, optimizer = fabric.setup(model, optimizer)

    train_loader, val_loader, train_sampler = get_dataloaders(fabric, batch_size=128)
    # We built DistributedSamplers ourselves, so tell Fabric not to inject its own
    train_loader, val_loader = fabric.setup_dataloaders(
        train_loader, val_loader, use_distributed_sampler=False
    )

    for epoch in range(50):
        train_loss = train_one_epoch(fabric, model, optimizer, train_loader, train_sampler, epoch)
        val_acc = validate(fabric, model, val_loader)
        scheduler.step()
        if fabric.global_rank == 0:
            print(f"Epoch {epoch} | Train Loss: {train_loss:.4f} | Val Acc: {val_acc:.2f}%")
        if (epoch + 1) % 10 == 0:
            state = {"model": model, "optimizer": optimizer, "epoch": epoch}
            fabric.save("checkpoint.ckpt", state)  # fabric.save is rank-aware

    if fabric.global_rank == 0:
        print("Training complete.")


if __name__ == "__main__":
    main()
```
A few things worth calling out. The DistributedSampler splits the dataset evenly across all ranks (all GPUs across all nodes). Calling train_sampler.set_epoch(epoch) at the start of each epoch is mandatory – without it, every epoch uses the same shard order, which hurts convergence. The fabric.all_reduce call in validation aggregates counts across all ranks so you get global accuracy, not per-rank accuracy.
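The effect of `set_epoch` is easy to see in isolation: `DistributedSampler` can be built standalone by passing `num_replicas` and `rank` explicitly, no cluster required. This small sketch shows that the same epoch always yields the same shard order, while the epoch number reseeds the shuffle:

```python
from torch.utils.data import DistributedSampler

# Standalone sampler: pretend we are rank 0 of a 2-rank job over 8 samples
data = list(range(8))
sampler = DistributedSampler(data, num_replicas=2, rank=0, shuffle=True)

sampler.set_epoch(0)
epoch0 = list(sampler)   # this rank's shard order for epoch 0
sampler.set_epoch(1)
epoch1 = list(sampler)   # reseeded permutation for epoch 1
sampler.set_epoch(0)
replay = list(sampler)   # identical to epoch0: the order is deterministic per epoch

print(epoch0, epoch1, replay)
```

Skip the `set_epoch` call and every epoch behaves like `epoch0` above, so each rank grinds through the same sample order for the whole run.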
## Launching Across Multiple Nodes
You have two solid options for launching multi-node jobs. Pick the one that fits your infrastructure.
### Option 1: `fabric run` (recommended for most setups)
Run this on each node in your cluster:
```bash
# On node 0 (master)
fabric run \
  --accelerator=cuda \
  --devices=2 \
  --num-nodes=2 \
  --node-rank=0 \
  --main-address=192.168.1.100 \
  --main-port=29500 \
  train.py

# On node 1
fabric run \
  --accelerator=cuda \
  --devices=2 \
  --num-nodes=2 \
  --node-rank=1 \
  --main-address=192.168.1.100 \
  --main-port=29500 \
  train.py
```
The --main-address must point to node 0’s IP. Every node needs to reach that address on --main-port. NCCL also needs additional ports open for GPU-to-GPU traffic – typically a range above the main port.
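Before debugging NCCL itself, it is worth confirming plain TCP reachability. A small Python stand-in for `nc -zv` (the helper name `can_reach` is ours, not a Fabric API) that you can drop into any worker node:

```python
import socket

def can_reach(host: str, port: int, timeout: float = 3.0) -> bool:
    """TCP reachability check -- a Python stand-in for `nc -zv host port`."""
    try:
        # create_connection performs the full TCP handshake, then we close it
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        return False

# From each worker node, verify node 0 is reachable on the main port, e.g.:
# can_reach("192.168.1.100", 29500)
```

If this returns False from any node, no amount of NCCL tuning will help; fix the routing or firewall first.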
### Option 2: torchrun
If you’re already familiar with torchrun, Fabric works with it directly. You don’t even need the fabric.launch() call – torchrun handles process spawning:
```bash
# On node 0
torchrun \
  --nnodes=2 \
  --nproc_per_node=2 \
  --node_rank=0 \
  --master_addr=192.168.1.100 \
  --master_port=29500 \
  train.py

# On node 1
torchrun \
  --nnodes=2 \
  --nproc_per_node=2 \
  --node_rank=1 \
  --master_addr=192.168.1.100 \
  --master_port=29500 \
  train.py
```
### Option 3: Slurm (cluster environments)
For managed clusters, wrap it in an sbatch script:
```bash
#!/bin/bash
#SBATCH --job-name=fabric-multinode
#SBATCH --nodes=2
#SBATCH --ntasks-per-node=2
#SBATCH --gpus-per-node=2
#SBATCH --cpus-per-task=4
#SBATCH --time=02:00:00

# One srun task per GPU; Fabric detects the Slurm environment and
# assigns ranks from it, so no launcher flags are needed
srun python train.py
```
Slurm sets the node rank and master address environment variables automatically, so Fabric picks them up without explicit flags.
## Monitoring and Logging Across Ranks
When you’re running on 4+ GPUs across multiple machines, logging gets noisy fast. The key principle: only log from rank 0.
```python
from lightning.fabric import Fabric
from lightning.fabric.loggers import TensorBoardLogger, CSVLogger

# Set up loggers -- only rank 0 actually writes
tb_logger = TensorBoardLogger(root_dir="logs", name="multinode_run")
csv_logger = CSVLogger(root_dir="logs", name="multinode_run")

fabric = Fabric(
    accelerator="cuda",
    devices=2,
    num_nodes=2,
    strategy="ddp",
    loggers=[tb_logger, csv_logger],
)
fabric.launch()

# Inside your training loop (train_loss, val_acc, global_step, etc.
# come from your own training code):
fabric.log("train/loss", train_loss, step=global_step)
fabric.log("train/accuracy", train_acc, step=global_step)
fabric.log("val/accuracy", val_acc, step=global_step)
fabric.log("lr", optimizer.param_groups[0]["lr"], step=global_step)
```
fabric.log writes only on rank 0 by default, so you won’t get duplicate entries. For custom print statements, guard them by rank (or use fabric.print, which applies the guard for you):
```python
if fabric.global_rank == 0:
    print(f"Step {step}: loss={loss.item():.4f}")

# equivalent built-in: fabric.print only prints on global rank 0
fabric.print(f"Step {step}: loss={loss.item():.4f}")
```
For monitoring NCCL health and GPU communication, set these environment variables before launching:
```bash
export NCCL_DEBUG=INFO                  # logs NCCL initialization and topology
export NCCL_DEBUG_SUBSYS=ALL            # full subsystem logging
export TORCH_DISTRIBUTED_DEBUG=DETAIL   # PyTorch distributed diagnostics
```
Keep NCCL_DEBUG=INFO on during initial setup – it prints which network interfaces and protocols NCCL chose. Once everything works, switch to NCCL_DEBUG=WARN to reduce noise.
## NCCL Configuration and Tuning
NCCL picks reasonable defaults, but you can tune it for your network topology. These environment variables matter most:
```bash
# Force NCCL onto a specific network interface (common on multi-NIC nodes)
export NCCL_SOCKET_IFNAME=eth0

# Use InfiniBand if available (much faster than TCP); 0 is the default
export NCCL_IB_DISABLE=0

# Buffer size for network transfers (default 4 MB); larger values can help
# on high-bandwidth links
export NCCL_BUFFSIZE=8388608

# Force tree-based allreduce; can help at larger GPU counts, though NCCL's
# automatic algorithm selection is usually a good default
export NCCL_ALGO=Tree
```
If your nodes have both InfiniBand and Ethernet, NCCL should detect and prefer IB automatically. If it doesn’t, check that NCCL_IB_DISABLE is not set to 1 and that the IB devices are visible (ibstat should list them).
For cloud instances (AWS, GCP), you often need to set the network interface explicitly. AWS p4d instances use EFA, which routes NCCL traffic through the aws-ofi-nccl plugin rather than plain TCP sockets, so verify that plugin is installed before tuning interface names.
## Common Errors and Fixes
**`RuntimeError: NCCL error: unhandled system error`**

This almost always means NCCL can’t establish connections between nodes. Check that:

- All nodes can reach the master address on the master port
- Firewall rules allow traffic on port 29500 and a range above it (NCCL opens additional ports)
- You’re using the correct network interface (`NCCL_SOCKET_IFNAME`)
```bash
# Test connectivity from node 1 to node 0
nc -zv 192.168.1.100 29500
```
**`RuntimeError: Timed out initializing process group`**
The default timeout is 30 minutes. If your nodes are slow to start (e.g., downloading datasets), increase it:
```python
from datetime import timedelta

from lightning.fabric import Fabric
from lightning.fabric.strategies import DDPStrategy

fabric = Fabric(
    accelerator="cuda",
    devices=2,
    num_nodes=2,
    # pass the strategy object directly to raise the init timeout
    strategy=DDPStrategy(timeout=timedelta(minutes=90)),
)
fabric.launch()
```
You can also set TORCH_DIST_INIT_BARRIER=1 so all ranks wait until every process has joined before proceeding.
**`NCCL WARN Bootstrap: no socket interface found`**
NCCL can’t find a valid network interface. Explicitly set one:
```bash
# List available interfaces
ip addr show

# Set the correct one
export NCCL_SOCKET_IFNAME=eth0
```
**Validation accuracy much lower than single-node training**
If you scale from 1 node to 2 nodes without adjusting anything, the effective batch size doubles. Either reduce the per-GPU batch size or scale the learning rate linearly:
```python
# Linear scaling rule: lr scales with world size
base_lr = 0.1
scaled_lr = base_lr * fabric.world_size
optimizer = torch.optim.SGD(model.parameters(), lr=scaled_lr, momentum=0.9, weight_decay=5e-4)
```
Alternatively, use learning rate warmup for the first few epochs to avoid divergence at high learning rates.
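A warmup can be sketched with stock PyTorch schedulers: ramp linearly for a few epochs, then hand off to the cosine schedule from the training script. The `world_size = 4` constant stands in for `fabric.world_size`, and the dummy `Linear` model just gives the optimizer parameters to manage:

```python
import torch

# Assumed setup: 50 epochs total, 5 of them warmup, linear-scaled peak lr
model = torch.nn.Linear(10, 10)
base_lr, world_size = 0.1, 4
optimizer = torch.optim.SGD(model.parameters(), lr=base_lr * world_size, momentum=0.9)

# Start at 10% of the peak lr and ramp to 100% over 5 epochs...
warmup = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.1, total_iters=5)
# ...then cosine-decay over the remaining 45 epochs
cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=45)
scheduler = torch.optim.lr_scheduler.SequentialLR(
    optimizer, schedulers=[warmup, cosine], milestones=[5]
)

lrs = []
for epoch in range(50):
    lrs.append(optimizer.param_groups[0]["lr"])
    # ... train_one_epoch / validate ...
    scheduler.step()
```

The learning rate starts at a tenth of the scaled peak, which keeps the first noisy gradients from blowing up the large-batch run, and only reaches the full linear-scaled value once training has settled.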
**`DataLoader worker exited unexpectedly`**
When using num_workers > 0 with distributed training, make sure the dataset download completes before spawning workers. Download data on rank 0 first, then barrier:
```python
if fabric.global_rank == 0:
    datasets.CIFAR10(root="./data", train=True, download=True)
fabric.barrier()  # all ranks wait here until rank 0 finishes downloading
```