Lightning Fabric sits between raw PyTorch and full PyTorch Lightning. You keep your training loop, but Fabric handles the ugly parts: device placement, distributed communication, mixed precision scaling. Five lines of code changes and your single-GPU script runs on a multi-GPU cluster.

Install both packages:

1
# torchvision is only needed for the CIFAR-10 dataset and transforms used in the full example
pip install lightning torch torchvision

Here’s the fastest way to see Fabric in action. Take any PyTorch training loop and wrap it:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
import torch
from lightning.fabric import Fabric

# "16-mixed" needs a CUDA GPU; use "bf16-mixed" or "32-true" on CPU-only machines.
fabric = Fabric(accelerator="auto", devices="auto", precision="16-mixed")
fabric.launch()

# MyModel and train_loader stand in for your existing model class and DataLoader.
model = MyModel()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Fabric wraps model and optimizer for distributed + mixed precision
model, optimizer = fabric.setup(model, optimizer)
train_loader = fabric.setup_dataloaders(train_loader)

for batch in train_loader:
    optimizer.zero_grad()
    loss = model(batch)  # placeholder: assumes the model returns its loss; adapt to your loop
    fabric.backward(loss)  # replaces loss.backward()
    optimizer.step()

That’s the pattern. Besides constructing Fabric and calling fabric.launch(), you add three calls: fabric.setup(), fabric.setup_dataloaders(), and fabric.backward(). Everything else stays the same.

Training a CNN on CIFAR-10

Here’s a complete, runnable example. We’ll train a small convolutional network on CIFAR-10 with mixed precision across however many GPUs you have available.

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from lightning.fabric import Fabric


class SimpleCNN(nn.Module):
    """Small CNN: three conv/ReLU/max-pool stages followed by a two-layer head.

    The head's ``128 * 4 * 4`` input size implies 3-channel 32x32 inputs
    (e.g. CIFAR-10): each of the three 2x2 poolings halves the spatial size,
    32 -> 16 -> 8 -> 4, while the padded 3x3 convolutions preserve it.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # Submodules are registered in a fixed order so parameter ordering
        # (and therefore optimizer/checkpoint layout) stays stable.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(in_features=128 * 4 * 4, out_features=256)
        self.fc2 = nn.Linear(in_features=256, out_features=num_classes)
        self.dropout = nn.Dropout(p=0.25)

    def forward(self, x):
        """Return unnormalized class scores with shape (batch, num_classes)."""
        for conv in (self.conv1, self.conv2, self.conv3):
            x = self.pool(F.relu(conv(x)))
        batch_size = x.size(0)
        hidden = F.relu(self.fc1(x.view(batch_size, -1)))
        return self.fc2(self.dropout(hidden))


def _make_dataloaders(fabric):
    """Build the CIFAR-10 train/val DataLoaders (before Fabric wraps them).

    Only the training set is augmented; the validation set is just normalized
    so the reported validation accuracy is deterministic.
    """
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, padding=4),
        transforms.ToTensor(),
        normalize,
    ])
    val_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    # Rank 0 downloads/extracts first so several processes don't race on ./data.
    with fabric.rank_zero_first():
        train_dataset = datasets.CIFAR10(root="./data", train=True, download=True, transform=train_transform)
        val_dataset = datasets.CIFAR10(root="./data", train=False, download=True, transform=val_transform)

    train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=256, shuffle=False, num_workers=4)
    return train_loader, val_loader


def _train_epoch(fabric, model, optimizer, train_loader, epoch):
    """Run one training epoch, printing progress every 100 batches."""
    model.train()
    correct = 0
    total = 0

    for batch_idx, (inputs, targets) in enumerate(train_loader):
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = F.cross_entropy(outputs, targets)
        fabric.backward(loss)
        optimizer.step()

        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        if batch_idx % 100 == 0:
            # Running accuracy of this process's shard only (not aggregated across ranks).
            fabric.print(
                f"Epoch {epoch} | Batch {batch_idx} | "
                f"Loss: {loss.item():.4f} | Acc: {100.0 * correct / total:.1f}%"
            )


def _evaluate(fabric, model, val_loader):
    """Return validation accuracy in percent, summed over all processes."""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in val_loader:
            _, predicted = model(inputs).max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    # Each process only evaluated its own shard of val_loader; combine the counts.
    # NOTE(review): a distributed sampler may pad shards with repeated samples when
    # the dataset size is not divisible by the world size, slightly skewing counts.
    counts = torch.tensor([correct, total], device=fabric.device)
    counts = fabric.all_reduce(counts, reduce_op="sum")
    return 100.0 * counts[0].item() / counts[1].item()


def main(num_epochs=20):
    """Train SimpleCNN on CIFAR-10 with Fabric, validating and checkpointing every epoch.

    Args:
        num_epochs: Number of training epochs. Also used as the cosine schedule's
            ``T_max`` so the learning rate anneals over the whole run.
    """
    # "16-mixed" needs a CUDA GPU; use "bf16-mixed" or "32-true" on CPU-only machines.
    fabric = Fabric(accelerator="auto", devices="auto", precision="16-mixed")
    fabric.launch()

    train_loader, val_loader = _make_dataloaders(fabric)

    # Model, optimizer and LR schedule
    model = SimpleCNN(num_classes=10)
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-2)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

    # Fabric setup: device placement, precision, DDP wrapping and data sharding
    model, optimizer = fabric.setup(model, optimizer)
    train_loader, val_loader = fabric.setup_dataloaders(train_loader, val_loader)

    for epoch in range(num_epochs):
        _train_epoch(fabric, model, optimizer, train_loader, epoch)
        scheduler.step()

        val_acc = _evaluate(fabric, model, val_loader)
        fabric.print(f"Epoch {epoch} | Val Acc: {val_acc:.1f}%")

        # Save checkpoint (scheduler included so a resumed run continues the LR schedule)
        state = {
            "model": model,
            "optimizer": optimizer,
            "scheduler": scheduler,
            "epoch": epoch,
            "val_acc": val_acc,
        }
        fabric.save("checkpoint.ckpt", state)


if __name__ == "__main__":
    main()

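Run it on a single GPU and it works. Run it on 4 GPUs and it still works, with no code changes. Fabric handles the DistributedDataParallel wrapping, gradient synchronization, and data sharding behind the scenes. Two things it does not do for you:
- Aggregate metrics you compute yourself. Each process only sees its own shard, so counts like correct predictions are per process; combine them with fabric.all_reduce before reporting a global accuracy.
- Coordinate dataset downloads. Several processes downloading into ./data at once can collide, so wrap the download in `with fabric.rank_zero_first():`.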

Fabric Configuration Options

The Fabric constructor controls your hardware strategy. Here are the options you’ll actually use:

Accelerator and Devices

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
# Alternative configurations -- pick one of these, not all of them.

# Auto-detect everything
fabric = Fabric(accelerator="auto", devices="auto")

# Specific GPU count (2 GPUs)
fabric = Fabric(accelerator="gpu", devices=2)

# Specific GPU IDs: a list selects exactly those devices (here GPUs 0 and 2)
fabric = Fabric(accelerator="gpu", devices=[0, 2])

# CPU only
fabric = Fabric(accelerator="cpu")

Precision

Mixed precision saves memory and speeds up training on modern GPUs. The options that matter:

  • "32-true" – Full float32. Safe but slow.
  • "16-mixed" – Float16 mixed precision with autocast and loss scaling. Needs a CUDA GPU; it pays off most on GPUs with tensor cores (Volta and newer).
  • "bf16-mixed" – BFloat16 mixed precision. More numerically stable than fp16 and needs no loss scaling; requires native bf16 support (Ampere GPUs such as the A100, and newer).
1
2
# BFloat16 on A100 or H100 (needs a GPU with native bf16 support: Ampere generation or newer)
fabric = Fabric(accelerator="gpu", devices="auto", precision="bf16-mixed")

Distributed Strategy

For multi-node or advanced parallelism:

1
2
3
4
5
6
7
8
# DDP (default for multi-GPU)
fabric = Fabric(accelerator="gpu", devices=4, strategy="ddp")

# Multi-node DDP: 4 GPUs on each of 2 machines = 8 processes in total.
# The script must be started on every node (e.g. through SLURM or torchrun).
fabric = Fabric(accelerator="gpu", devices=4, num_nodes=2, strategy="ddp")

# FSDP for large models that don't fit on a single GPU
fabric = Fabric(accelerator="gpu", devices=4, strategy="fsdp")

# DeepSpeed integration (requires installing the separate `deepspeed` package)
fabric = Fabric(accelerator="gpu", devices=4, strategy="deepspeed_stage_2")

Checkpointing and Resuming

Fabric’s save and load handle distributed state correctly. With DDP, only rank 0 writes the file, and on load every process reads the full checkpoint. With sharded strategies such as FSDP, each process saves and loads its own shard.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
# Save. With DDP only rank 0 writes the file; sharded strategies (FSDP, DeepSpeed)
# may instead write one shard per rank.
state = {
    "model": model,
    "optimizer": optimizer,
    "scheduler": scheduler,
    "epoch": epoch,
    "best_val_acc": best_val_acc,
}
fabric.save(f"checkpoints/epoch_{epoch}.ckpt", state)

# Load and resume: model, optimizer and scheduler are restored in place.
state = {"model": model, "optimizer": optimizer, "scheduler": scheduler}
remainder = fabric.load("checkpoints/epoch_5.ckpt", state)
# remainder holds the checkpoint entries that were NOT requested in `state`,
# e.g. {"epoch": 5, "best_val_acc": 92.3}
resumed_epoch = remainder["epoch"]
start_epoch = resumed_epoch + 1  # the saved epoch already finished; continue with the next one
best_val_acc = remainder["best_val_acc"]

One thing to watch: call fabric.load() after fabric.setup(). Fabric needs the wrapped model and optimizer to map checkpoint state correctly.

Launching Multi-GPU Training

You have two options. The simplest is letting Fabric handle process spawning internally:

1
2
fabric = Fabric(accelerator="gpu", devices=4)
fabric.launch()  # starts 4 processes in total, one per device

For more control, wrap your training in a function and pass it to fabric.launch():

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
import torch
from lightning.fabric import Fabric


def train(fabric):
    """Training entry point; fabric.launch(train) calls train(fabric) in every process."""
    model = SimpleCNN()  # SimpleCNN from the CIFAR-10 example
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    model, optimizer = fabric.setup(model, optimizer)
    # ... training loop


if __name__ == "__main__":
    # The guard is required for spawn-based strategies (e.g. strategy="ddp_spawn"),
    # which re-import this module in each child process.
    fabric = Fabric(accelerator="gpu", devices=4, precision="bf16-mixed")
    fabric.launch(train)

You can also skip Fabric’s own process spawning and let torchrun create the processes from the command line:

1
# Starts 4 processes on this machine; torchrun sets RANK, LOCAL_RANK and WORLD_SIZE for each one
torchrun --nproc_per_node=4 train.py

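When you use torchrun, Fabric detects the distributed environment automatically, and fabric.launch() does not start any extra processes. What you skip is the spawning, not the call. Keep fabric.launch() in your script: Fabric expects it to have been called before setup() whenever more than one device is used.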

Logging Metrics

Fabric integrates with common loggers. Pass them at initialization:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
from lightning.fabric import Fabric
from lightning.fabric.loggers import CSVLogger, TensorBoardLogger

tb_logger = TensorBoardLogger(root_dir="logs", name="cifar10_run")
csv_logger = CSVLogger(root_dir="logs", name="cifar10_run")

fabric = Fabric(loggers=[tb_logger, csv_logger])
fabric.launch()

# Inside training loop. global_step, accuracy, val_loss, val_acc and scheduler come
# from your own loop; you maintain global_step yourself (e.g. +1 per optimizer.step()).
fabric.log("train/loss", loss.item(), step=global_step)
fabric.log("train/acc", accuracy, step=global_step)
fabric.log_dict(
    {"val/loss": val_loss, "val/acc": val_acc, "lr": scheduler.get_last_lr()[0]},
    step=global_step,
)

fabric.log is rank-aware. Only rank 0 writes to disk, so you don’t get duplicate entries from multiple GPUs.

Common Errors and Fixes

RuntimeError: Expected all tensors to be on the same device

This happens when you create tensors inside the training loop without placing them on the right device. Use fabric.device:

1
2
3
4
5
# Wrong: torch.zeros allocates on the CPU by default, while Fabric has moved the model
# and batches to fabric.device
labels = torch.zeros(batch_size, dtype=torch.long)

# Right: allocate directly on the device Fabric selected for this process
labels = torch.zeros(batch_size, dtype=torch.long, device=fabric.device)

RuntimeError: Modules with uninitialized parameters can't be used with Fabric

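Call fabric.setup(model, optimizer) only after the model is fully initialized. Lazy modules (such as nn.LazyLinear) create their parameters on the first forward pass, so run one dummy batch through the model before calling fabric.setup().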

ValueError: precision='16-mixed' requires a GPU

Mixed precision with float16 needs a CUDA device. On CPU, use "bf16-mixed" or "32-true":

1
2
# For CPU-only machines ("32-true" is the fallback if bf16 is slow or unsupported on your CPU)
fabric = Fabric(accelerator="cpu", precision="bf16-mixed")

Checkpoints are huge / saving is slow

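If you’re using FSDP, make sure the strategy saves sharded checkpoints (state_dict_type="sharded", Fabric’s FSDP default). Each rank then writes its own shard instead of gathering the full model on rank 0: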

1
2
3
4
5
6
7
from lightning.fabric import Fabric
from lightning.fabric.strategies import FSDPStrategy

fabric = Fabric(
    # "sharded" (Fabric's FSDP default) saves one shard per rank instead of
    # gathering the full model and optimizer state on rank 0 ("full").
    strategy=FSDPStrategy(state_dict_type="sharded"),
    accelerator="gpu",
    devices=4,
)
fabric.launch()

# Use sharded checkpoints: with "sharded", this path becomes a directory of per-rank files
fabric.save("checkpoint.ckpt", state)

fabric.backward(loss) hangs in multi-GPU

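Usually this means one process hit an error while the others wait at a synchronization point, so check every process’s logs. Another common cause is processes running a different number of steps. For example, a data-dependent break/continue taken on only some ranks, or a custom sampler that gives ranks unequal numbers of batches, leaves one rank waiting in backward for gradients the others never send. Make sure every rank runs the same number of steps. Dropping the last incomplete batch also keeps batch shapes uniform at the end of each epoch: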

1
# Drop the final incomplete batch, then let Fabric shard the loader across ranks as usual.
train_loader = DataLoader(dataset, batch_size=128, drop_last=True)
train_loader = fabric.setup_dataloaders(train_loader)