Lightning Fabric sits between raw PyTorch and full PyTorch Lightning. You keep your training loop, but Fabric handles the ugly parts: device placement, distributed communication, mixed precision scaling. Five lines of code changes and your single-GPU script runs on a multi-GPU cluster.
Install the dependencies:
```shell
pip install lightning torch torchvision
```
Here’s the fastest way to see Fabric in action. Take any PyTorch training loop and wrap it:
```python
import torch
from lightning.fabric import Fabric

fabric = Fabric(accelerator="auto", devices="auto", precision="16-mixed")
fabric.launch()

model = MyModel()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Fabric wraps model and optimizer for distributed + mixed precision
model, optimizer = fabric.setup(model, optimizer)
train_loader = fabric.setup_dataloaders(train_loader)

for batch in train_loader:
    optimizer.zero_grad()
    loss = model(batch)
    fabric.backward(loss)  # replaces loss.backward()
    optimizer.step()
```
That’s the pattern. Three calls: fabric.setup(), fabric.setup_dataloaders(), and fabric.backward(). Everything else stays the same.
Training a CNN on CIFAR-10
Here’s a complete, runnable example. We’ll train a small convolutional network on CIFAR-10 with mixed precision across however many GPUs you have available.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from lightning.fabric import Fabric


class SimpleCNN(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(128 * 4 * 4, 256)
        self.fc2 = nn.Linear(256, num_classes)
        self.dropout = nn.Dropout(0.25)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        x = x.view(x.size(0), -1)
        x = self.dropout(F.relu(self.fc1(x)))
        x = self.fc2(x)
        return x


def main():
    fabric = Fabric(accelerator="auto", devices="auto", precision="16-mixed")
    fabric.launch()

    # Data: augment the training set only; validation gets a deterministic pipeline
    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, padding=4),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
    ])
    val_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
    ])
    train_dataset = datasets.CIFAR10(root="./data", train=True, download=True, transform=train_transform)
    val_dataset = datasets.CIFAR10(root="./data", train=False, download=True, transform=val_transform)
    train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=256, shuffle=False, num_workers=4)

    # Model and optimizer
    model = SimpleCNN(num_classes=10)
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-2)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20)

    # Fabric setup
    model, optimizer = fabric.setup(model, optimizer)
    train_loader, val_loader = fabric.setup_dataloaders(train_loader, val_loader)

    # Training loop
    for epoch in range(20):
        model.train()
        total_loss = 0.0
        correct = 0
        total = 0
        for batch_idx, (inputs, targets) in enumerate(train_loader):
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = F.cross_entropy(outputs, targets)
            fabric.backward(loss)
            optimizer.step()

            total_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            if batch_idx % 100 == 0:
                fabric.print(
                    f"Epoch {epoch} | Batch {batch_idx} | "
                    f"Loss: {loss.item():.4f} | Acc: {100.0 * correct / total:.1f}%"
                )
        scheduler.step()

        # Validation
        model.eval()
        val_correct = 0
        val_total = 0
        with torch.no_grad():
            for inputs, targets in val_loader:
                outputs = model(inputs)
                _, predicted = outputs.max(1)
                val_total += targets.size(0)
                val_correct += predicted.eq(targets).sum().item()
        val_acc = 100.0 * val_correct / val_total
        fabric.print(
            f"Epoch {epoch} | Train Loss: {total_loss / len(train_loader):.4f} | "
            f"Val Acc: {val_acc:.1f}%"
        )

        # Save checkpoint
        state = {
            "model": model,
            "optimizer": optimizer,
            "epoch": epoch,
            "val_acc": val_acc,
        }
        fabric.save("checkpoint.ckpt", state)


if __name__ == "__main__":
    main()
```
Run it on a single GPU and it works. Run it on 4 GPUs and it still works, no code changes. Fabric handles the DistributedDataParallel wrapping, gradient synchronization, and data sharding behind the scenes.
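For a sense of what that saves you, here is a rough sketch of the raw-PyTorch DDP setup those three Fabric calls replace. This is illustrative only, not Fabric's actual internals, and the name `manual_ddp_setup` is invented for the example:

```python
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler


def manual_ddp_setup(model, dataset, rank, world_size, batch_size=128):
    # Roughly what fabric.launch() + fabric.setup() + fabric.setup_dataloaders()
    # handle for you, written out by hand.
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
    model = DDP(model.to(rank), device_ids=[rank])

    # Each rank must see a disjoint shard of the data
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)
    return model, loader
```

And that sketch still omits mixed-precision scaling and rank-aware printing, which Fabric also covers.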
Fabric Configuration Options
The Fabric constructor controls your hardware strategy. Here are the options you’ll actually use:
Accelerator and Devices
```python
# Auto-detect everything
fabric = Fabric(accelerator="auto", devices="auto")

# Specific GPU count
fabric = Fabric(accelerator="gpu", devices=2)

# Specific GPU IDs
fabric = Fabric(accelerator="gpu", devices=[0, 2])

# CPU only
fabric = Fabric(accelerator="cpu")
```
Precision
Mixed precision saves memory and speeds up training on modern GPUs. The options that matter:
"32-true" – Full float32. Safe but slow."16-mixed" – Float16 mixed precision with autocast. Best for most workloads on NVIDIA Ampere+ GPUs."bf16-mixed" – BFloat16 mixed precision. More numerically stable than fp16, works well on A100 and newer.
```python
# BFloat16 on A100 or H100
fabric = Fabric(accelerator="gpu", devices="auto", precision="bf16-mixed")
```
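Under the hood, mixed precision runs eligible ops inside a torch.autocast region while keeping the parameters in float32. A minimal sketch of the effect, using CPU bfloat16 so it runs anywhere:

```python
import torch
import torch.nn as nn

model = nn.Linear(8, 4)  # parameters stay float32
x = torch.randn(2, 8)

# Roughly the context Fabric puts around your forward pass with "bf16-mixed"
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    out = model(x)

print(out.dtype)           # activations computed in low precision
print(model.weight.dtype)  # weights untouched, still float32
```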
Distributed Strategy
For multi-node or advanced parallelism:
```python
# DDP (default for multi-GPU)
fabric = Fabric(accelerator="gpu", devices=4, strategy="ddp")

# FSDP for large models that don't fit on a single GPU
fabric = Fabric(accelerator="gpu", devices=4, strategy="fsdp")

# DeepSpeed integration
fabric = Fabric(accelerator="gpu", devices=4, strategy="deepspeed_stage_2")
```
Checkpointing and Resuming
Fabric’s save and load handle distributed state correctly. When running with DDP, only rank 0 writes the file; on load, every process restores its own copy of the state.
```python
# Save (works in distributed — only rank 0 writes)
state = {
    "model": model,
    "optimizer": optimizer,
    "scheduler": scheduler,
    "epoch": epoch,
    "best_val_acc": best_val_acc,
}
fabric.save(f"checkpoints/epoch_{epoch}.ckpt", state)

# Load and resume
state = {"model": model, "optimizer": optimizer, "scheduler": scheduler}
remainder = fabric.load("checkpoints/epoch_5.ckpt", state)

# remainder contains non-tensor state like {"epoch": 5, "best_val_acc": 92.3}
resumed_epoch = remainder["epoch"]
best_val_acc = remainder["best_val_acc"]
```
One thing to watch: call fabric.load() after fabric.setup(). Fabric needs the wrapped model and optimizer to map checkpoint state correctly.
Launching Multi-GPU Training
You have two options. The simplest is letting Fabric handle process spawning internally:
```python
fabric = Fabric(accelerator="gpu", devices=4)
fabric.launch()  # spawns 4 processes
```
For more control, wrap your training in a function and pass it to fabric.launch():
```python
from lightning.fabric import Fabric


def train(fabric):
    model = SimpleCNN()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    model, optimizer = fabric.setup(model, optimizer)
    # ... training loop


fabric = Fabric(accelerator="gpu", devices=4, precision="bf16-mixed")
fabric.launch(train)
```
You can also skip fabric.launch() entirely and use torchrun from the command line:
```shell
torchrun --nproc_per_node=4 train.py
```
When you use torchrun, Fabric detects the distributed environment automatically. No fabric.launch() call needed.
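torchrun communicates the topology through environment variables (RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT), and that is what Fabric reads at init. A quick way to see what each spawned process was given, sketched with fallbacks so it also runs as a plain single-process script:

```python
import os

# Set by torchrun for every spawned process; Fabric picks these up automatically
rank = int(os.environ.get("RANK", 0))
local_rank = int(os.environ.get("LOCAL_RANK", 0))
world_size = int(os.environ.get("WORLD_SIZE", 1))

print(f"rank {rank} (local {local_rank}) of {world_size}")
```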
Logging Metrics
Fabric integrates with common loggers. Pass them at initialization:
```python
from lightning.fabric.loggers import TensorBoardLogger, CSVLogger

tb_logger = TensorBoardLogger(root_dir="logs", name="cifar10_run")
csv_logger = CSVLogger(root_dir="logs", name="cifar10_run")

fabric = Fabric(loggers=[tb_logger, csv_logger])
fabric.launch()

# Inside the training loop
fabric.log("train/loss", loss.item(), step=global_step)
fabric.log("train/acc", accuracy, step=global_step)
fabric.log_dict(
    {"val/loss": val_loss, "val/acc": val_acc, "lr": scheduler.get_last_lr()[0]},
    step=global_step,
)
```
fabric.log is rank-aware. Only rank 0 writes to disk, so you don’t get duplicate entries from multiple GPUs.
Common Errors and Fixes
RuntimeError: Expected all tensors to be on the same device
This happens when you create tensors inside the training loop without placing them on the right device. Use fabric.device:
```python
# Wrong: lands on CPU while the model sits on the accelerator
labels = torch.zeros(batch_size, dtype=torch.long)

# Right
labels = torch.zeros(batch_size, dtype=torch.long, device=fabric.device)
```
RuntimeError: Modules with uninitialized parameters can't be used with Fabric
Call fabric.setup(model, optimizer) after the model is fully initialized. Lazy modules need a forward pass first.
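Lazy modules (nn.LazyLinear, nn.LazyConv2d, and friends) only materialize their parameters once they see a real input shape, so run a dummy batch through the model before handing it to fabric.setup(). A minimal sketch with made-up shapes:

```python
import torch
import torch.nn as nn

# LazyLinear infers in_features from the first batch it sees
model = nn.Sequential(nn.LazyLinear(64), nn.ReLU(), nn.Linear(64, 10))

# One dummy forward pass materializes the lazy parameters
dummy = torch.randn(1, 32)
model(dummy)

print(model[0].weight.shape)  # now a concrete (64, 32) weight
# At this point fabric.setup(model, optimizer) is safe to call
```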
ValueError: precision='16-mixed' requires a GPU
Mixed precision with float16 needs a CUDA device. On CPU, use "bf16-mixed" or "32-true":
```python
# For CPU-only machines
fabric = Fabric(accelerator="cpu", precision="bf16-mixed")
```
Checkpoints are huge / saving is slow
If you’re using FSDP, set state_dict_type to avoid gathering the full model on rank 0:
```python
from lightning.fabric.strategies import FSDPStrategy

fabric = Fabric(
    strategy=FSDPStrategy(state_dict_type="sharded"),
    accelerator="gpu",
    devices=4,
)

# Each rank writes its own shard instead of gathering the full model on rank 0
fabric.save("checkpoint.ckpt", state)
```
fabric.backward(loss) hangs in multi-GPU
Usually means one process hit an error while others are waiting at a synchronization point. Check all processes’ logs. A common cause is uneven batch sizes at epoch end. Fix by dropping the last incomplete batch:
```python
train_loader = DataLoader(dataset, batch_size=128, drop_last=True)
```