MosaicML Composer wraps PyTorch’s Fully Sharded Data Parallel (FSDP) with a clean Trainer interface and pluggable speed-up algorithms. You get gradient accumulation, mixed precision, checkpoint management, and logging without wiring any of it yourself. The payoff: you write a training config, point it at a HuggingFace model, and Composer handles the distributed orchestration.

Install the stack (Composer is published on PyPI as mosaicml, which also pulls in PyTorch; datasets is needed for the SST-2 loader used below):

pip install mosaicml transformers datasets

Verify your PyTorch version and GPU count (recent Composer releases require PyTorch 2.x; check your release's requirements for the exact minimum):

import torch
print(torch.__version__)           # should be a 2.x release
print(torch.cuda.device_count())   # needs to be > 1 for sharding to help

Setting Up a Composer Trainer with FSDP

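Composer’s Trainer takes an fsdp_config dictionary whose keys mirror PyTorch FSDP options (newer Composer releases move it under parallelism_config={"fsdp": {...}}). Here’s a minimal setup that fine-tunes BERT on SST-2 with the model sharded across all available GPUs: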

import torch
from composer import Trainer
from composer.models import HuggingFaceModel
from composer.optim import DecoupledAdamW, LinearWithWarmupScheduler
from composer.utils import dist
from torch.utils.data import DataLoader
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from datasets import load_dataset

# Load a HuggingFace model and wrap it for Composer
hf_model = AutoModelForSequenceClassification.from_pretrained(
    "bert-base-uncased", num_labels=2
)
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
composer_model = HuggingFaceModel(hf_model, tokenizer=tokenizer, use_logits=True)

# Prepare a dataset
dataset = load_dataset("glue", "sst2", split="train")

def tokenize_fn(examples):
    return tokenizer(examples["sentence"], truncation=True, padding="max_length", max_length=128)

dataset = dataset.map(tokenize_fn, batched=True, remove_columns=["sentence", "idx"])
# HuggingFace models compute the loss from a `labels` argument, not `label`
dataset = dataset.rename_column("label", "labels")
dataset.set_format("torch", columns=["input_ids", "attention_mask", "labels"])
# batch_size is per GPU; the distributed sampler gives each rank its own shard of the data
train_loader = DataLoader(
    dataset, batch_size=16, sampler=dist.get_sampler(dataset, shuffle=True)
)

# FSDP config
fsdp_config = {
    "sharding_strategy": "FULL_SHARD",
    "cpu_offload": False,
    "mixed_precision": "PURE",
    "backward_prefetch": "BACKWARD_PRE",
    "activation_checkpointing": True,
    "activation_cpu_offload": False,
    "limit_all_gathers": True,
    "verbose": False,
}

optimizer = DecoupledAdamW(composer_model.parameters(), lr=2e-5, weight_decay=0.01)
scheduler = LinearWithWarmupScheduler(t_warmup="0.1dur", alpha_f=0.0)

trainer = Trainer(
    model=composer_model,
    train_dataloader=train_loader,
    optimizers=optimizer,
    schedulers=scheduler,
    max_duration="3ep",
    device_train_microbatch_size=4,
    precision="amp_bf16",
    fsdp_config=fsdp_config,
    seed=42,
)

trainer.fit()

Key things happening here:

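  • sharding_strategy: FULL_SHARD partitions parameters, gradients, and optimizer states across GPUs. It is analogous to DeepSpeed ZeRO-3.
  • device_train_microbatch_size controls the per-GPU microbatch. Composer treats the DataLoader's batch_size as the per-GPU batch, so when the microbatch is smaller, it splits each batch and accumulates gradients automatically.
  • precision: amp_bf16 uses bf16 automatic mixed precision, which requires an Ampere or newer GPU. bf16 keeps fp32's exponent range, so unlike fp16 it trains stably without loss scaling.
  • activation_checkpointing: True recomputes activations during backward to save memory, at the cost of roughly one extra forward pass (about a third more compute). Composer applies it only to submodules that opt in, through an _activation_checkpointing attribute or an activation_checkpointing_fn on the model (the same mechanism, via _fsdp_wrap / fsdp_wrap_fn, decides how the model is split into FSDP units), so check Composer's FSDP docs to confirm your model defines them.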
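The batch-size bookkeeping behind the microbatch bullet can be sketched in plain Python. batch_plan is a hypothetical helper, not part of Composer; it assumes, as described above, that the DataLoader's batch_size is the per-GPU batch, and it assumes 4 GPUs to match the launch examples later on.

```python
def batch_plan(per_gpu_batch: int, microbatch: int, world_size: int) -> dict:
    """Hypothetical helper: how a per-GPU batch splits into microbatches."""
    if per_gpu_batch % microbatch != 0:
        raise ValueError("this sketch assumes per_gpu_batch is a multiple of microbatch")
    return {
        # forward/backward passes accumulated before each optimizer step
        "accumulation_steps": per_gpu_batch // microbatch,
        # samples contributing to one optimizer step across all GPUs
        "global_batch": per_gpu_batch * world_size,
    }


# The setup above: batch_size=16, device_train_microbatch_size=4, assumed 4 GPUs
print(batch_plan(16, 4, 4))  # {'accumulation_steps': 4, 'global_batch': 64}
```

The same arithmetic with the full script's settings later on (batch_size=32, microbatch 8, 4 GPUs) gives 4 accumulation steps and a global batch of 128.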

FSDP Sharding Strategies

The fsdp_config accepts PyTorch's FSDP sharding strategies; these three cover most use cases:

FULL_SHARD

Parameters, gradients, and optimizer states are all sharded. Maximum memory savings, highest communication overhead. Use this when your model doesn’t fit on a single GPU even with mixed precision.

fsdp_config = {
    "sharding_strategy": "FULL_SHARD",
    "cpu_offload": False,
    "mixed_precision": "PURE",
    "backward_prefetch": "BACKWARD_PRE",
    "activation_checkpointing": True,
    "limit_all_gathers": True,
    "verbose": False,
}

SHARD_GRAD_OP

Gradients and optimizer states are sharded, but parameters, once gathered for the forward pass, stay unsharded until backward finishes. That means less communication than FULL_SHARD but more memory. This is analogous to DeepSpeed ZeRO-2 and works well when the model fits on each GPU but optimizer memory is the bottleneck.

fsdp_config = {
    "sharding_strategy": "SHARD_GRAD_OP",
    "cpu_offload": False,
    "mixed_precision": "PURE",
    "backward_prefetch": "BACKWARD_PRE",
    "activation_checkpointing": False,
    "limit_all_gathers": True,
    "verbose": False,
}

NO_SHARD

No sharding – standard DDP behavior. Useful for debugging or when you have plenty of GPU memory and want maximum throughput without any FSDP overhead.
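To see how the three strategies trade memory for communication, here is a back-of-envelope estimate using the accounting from the ZeRO paper: with mixed-precision Adam, each parameter costs 2 bytes of half-precision weights, 2 bytes of gradients, and 12 bytes of fp32 optimizer state (master weights plus two moments). It ignores activations, buffers, and FSDP's temporary all-gather buffers, and PyTorch FSDP's real memory layout differs in detail, so treat the results as rough orders of magnitude. The 110M parameter count for bert-base-uncased is approximate.

```python
def per_gpu_state_bytes(num_params: int, world_size: int, strategy: str) -> float:
    """Rough per-GPU bytes for weights, gradients, and Adam state (ZeRO-style accounting)."""
    weights = 2 * num_params  # half-precision copy used for compute
    grads = 2 * num_params    # half-precision gradients
    optim = 12 * num_params   # fp32 master weights + Adam's two moments
    if strategy == "NO_SHARD":       # everything replicated, like DDP
        return weights + grads + optim
    if strategy == "SHARD_GRAD_OP":  # ZeRO-2-like: weights replicated, the rest sharded
        return weights + (grads + optim) / world_size
    if strategy == "FULL_SHARD":     # ZeRO-3-like: everything sharded
        return (weights + grads + optim) / world_size
    raise ValueError(f"unknown strategy: {strategy}")


for name in ("NO_SHARD", "SHARD_GRAD_OP", "FULL_SHARD"):
    gigabytes = per_gpu_state_bytes(110_000_000, 4, name) / 1e9
    print(f"{name:14s} {gigabytes:.2f} GB")
```

For BERT-base all three fit comfortably; the gap only matters at scale. A 7B-parameter model needs about 112 GB per GPU under NO_SHARD by this estimate, which is why FULL_SHARD becomes necessary.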

Pick your strategy based on this rule of thumb: start with SHARD_GRAD_OP. If you OOM, switch to FULL_SHARD. If you still OOM, add cpu_offload: True and activation_checkpointing: True.
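The rule of thumb above can be written down as an escalation ladder. build_fsdp_config is a hypothetical helper, not a Composer API; the keys it sets are the fsdp_config keys used throughout this section.

```python
def build_fsdp_config(oom_retries: int) -> dict:
    """Hypothetical helper: add memory savings each time a run runs out of memory."""
    config = {
        "sharding_strategy": "SHARD_GRAD_OP",  # first attempt
        "cpu_offload": False,
        "mixed_precision": "PURE",
        "backward_prefetch": "BACKWARD_PRE",
        "activation_checkpointing": False,
        "limit_all_gathers": True,
        "verbose": False,
    }
    if oom_retries >= 1:
        config["sharding_strategy"] = "FULL_SHARD"  # shard parameters as well
    if oom_retries >= 2:
        config["cpu_offload"] = True                # offload parameters to CPU memory
        config["activation_checkpointing"] = True   # recompute activations in backward
    return config
```

Pass the result as fsdp_config=build_fsdp_config(n) when relaunching, bumping n after each OOM.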

Speed-Up Algorithms

Composer ships with a library of algorithms that speed up training or reduce memory usage. You plug them into the Trainer as a list:

from composer.algorithms import GradientClipping, LayerFreezing, FusedLayerNorm

algorithms = [
    GradientClipping(clipping_type="norm", clipping_threshold=1.0),
    FusedLayerNorm(),
]

trainer = Trainer(
    model=composer_model,
    train_dataloader=train_loader,
    optimizers=optimizer,
    schedulers=scheduler,
    max_duration="3ep",
    device_train_microbatch_size=4,
    precision="amp_bf16",
    fsdp_config=fsdp_config,
    algorithms=algorithms,
    seed=42,
)

Notable algorithms worth trying:

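  • GradientClipping – rescales gradients when their global norm exceeds the threshold, preventing exploding gradients. Standard practice for large models.
  • FusedLayerNorm – replaces torch.nn.LayerNorm with NVIDIA Apex's fused CUDA kernel for a speedup on NVIDIA GPUs. It requires Apex, and newer Composer releases deprecate it in favor of LowPrecisionLayerNorm.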
  • LayerFreezing – progressively freezes early layers during training. Reduces compute as training progresses.
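To make GradientClipping(clipping_type="norm", clipping_threshold=1.0) concrete, here is the norm-clipping rule in plain Python: compute one L2 norm over all gradients together and, if it exceeds the threshold, scale every gradient by the same factor. This mirrors the rule in torch.nn.utils.clip_grad_norm_ (including its small epsilon); it is an illustration on flat lists, not Composer's implementation. Under FSDP the same norm has to be computed across all shards, not per GPU.

```python
from __future__ import annotations

import math


def clip_by_global_norm(grads: list[list[float]], max_norm: float) -> list[list[float]]:
    """Scale all gradient tensors (flat lists here) by one shared factor."""
    total_norm = math.sqrt(sum(g * g for tensor in grads for g in tensor))
    clip_coef = max_norm / (total_norm + 1e-6)  # epsilon avoids division by zero
    if clip_coef >= 1.0:
        return grads  # already within the threshold: leave untouched
    return [[g * clip_coef for g in tensor] for tensor in grads]


# Two "parameters" whose combined gradient norm is 5.0
print(clip_by_global_norm([[3.0], [4.0]], max_norm=1.0))  # roughly [[0.6], [0.8]]
```

Because every gradient shares one factor, the update's direction is preserved; only its length is capped.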

Callbacks for Logging, Checkpointing, and Early Stopping

Composer’s callback system hooks into every stage of the training loop. Here’s how to set up the three you’ll always want:

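from composer.callbacks import (
    SpeedMonitor,
    LRMonitor,
    MemoryMonitor,
    EarlyStopper,
)
from composer.loggers import InMemoryLogger, FileLogger
from composer.models import HuggingFaceModel
from composer.utils import dist
from torchmetrics.classification import MulticlassAccuracy

# EarlyStopper needs a metric to watch, so give the model an accuracy metric
# (reuses hf_model, tokenizer, and tokenize_fn from the first example; in your own
# script, pass metrics= when you first build composer_model)
composer_model = HuggingFaceModel(
    hf_model,
    tokenizer=tokenizer,
    use_logits=True,
    metrics=[MulticlassAccuracy(num_classes=2)],
)

# ...and an eval dataloader to compute it on
eval_dataset = load_dataset("glue", "sst2", split="validation")
eval_dataset = eval_dataset.map(tokenize_fn, batched=True, remove_columns=["sentence", "idx"])
eval_dataset = eval_dataset.rename_column("label", "labels")
eval_dataset.set_format("torch", columns=["input_ids", "attention_mask", "labels"])
eval_loader = DataLoader(
    eval_dataset, batch_size=16, sampler=dist.get_sampler(eval_dataset, shuffle=False)
)

# Loggers ({rank} gives each GPU process its own log file)
file_logger = FileLogger(filename="training_log-rank{rank}.txt", flush_interval=50)
in_mem_logger = InMemoryLogger()

# Callbacks
speed_monitor = SpeedMonitor(window_size=50)
lr_monitor = LRMonitor()
memory_monitor = MemoryMonitor()
# Stop if eval accuracy fails to improve for two epochs
early_stopper = EarlyStopper(
    monitor="MulticlassAccuracy", dataloader_label="eval", patience="2ep"
)

trainer = Trainer(
    model=composer_model,
    train_dataloader=train_loader,
    eval_dataloader=eval_loader,
    eval_interval="1ep",
    optimizers=optimizer,
    schedulers=scheduler,
    max_duration="10ep",
    device_train_microbatch_size=4,
    precision="amp_bf16",
    fsdp_config=fsdp_config,
    loggers=[file_logger, in_mem_logger],
    callbacks=[speed_monitor, lr_monitor, memory_monitor, early_stopper],
    save_folder="./checkpoints",
    save_interval="1ep",
    save_num_checkpoints_to_keep=3,
    seed=42,
)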

trainer.fit()

What each callback does:

  • SpeedMonitor logs throughput (samples/sec, plus tokens/sec and model FLOPs utilization when Composer can compute them) – critical for spotting regressions.
  • LRMonitor logs the learning rate schedule so you can verify warmup and decay are working.
  • MemoryMonitor tracks GPU memory usage. Helps you decide if you can increase micro batch size.
  • EarlyStopper halts training if the monitored metric, on the dataloader it is attached to, doesn’t improve within the patience window.
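The patience logic behind EarlyStopper can be sketched as follows. This is a simplified illustration, not Composer's implementation: it checks the metric once per epoch, treats higher as better (as for accuracy; flip the comparison for a loss), and counts patience in epochs to match patience="2ep". The min_delta parameter mimics the improvement margin an early-stopping callback typically accepts.

```python
from __future__ import annotations


def early_stop_epoch(scores: list[float], patience: int, min_delta: float = 0.0) -> int | None:
    """Return the 0-based epoch after which training stops, or None if it never stops."""
    best = float("-inf")
    epochs_since_best = 0
    for epoch, score in enumerate(scores):
        if score > best + min_delta:
            best = score  # new best: reset the patience counter
            epochs_since_best = 0
        else:
            epochs_since_best += 1
            if epochs_since_best >= patience:
                return epoch
    return None


# Eval accuracy per epoch: peaks at epoch 1, then fails to improve twice in a row
print(early_stop_epoch([0.80, 0.85, 0.84, 0.83, 0.86], patience=2))  # 3
```

Note that the late improvement at epoch 4 is never seen: early stopping trades that risk for saved compute, which is why patience should cover normal epoch-to-epoch noise.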

Checkpointing is built into the Trainer itself: save_folder, save_interval, and save_num_checkpoints_to_keep handle saving and rotation automatically. Composer saves FSDP-aware checkpoints, so you can resume by passing load_path (or autoresume=True with a fixed run_name) without manually gathering sharded weights.
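save_num_checkpoints_to_keep boils down to "delete all but the newest N files". Here is a minimal stdlib sketch of that rotation policy; the ep{epoch}-ba{batch}-rank0.pt filenames imitate Composer-style naming but are only illustrative, and Composer's own rotation also deals with details this skips, such as latest-checkpoint symlinks and remote uploads.

```python
import re
import tempfile
from pathlib import Path

# Illustrative pattern for names like ep3-ba300-rank0.pt
CKPT_RE = re.compile(r"^ep(\d+)-ba(\d+)-rank0\.pt$")


def rotate_checkpoints(folder: Path, keep: int) -> list:
    """Delete all but the `keep` newest checkpoints; return the kept filenames."""
    if keep < 1:
        raise ValueError("keep must be at least 1")
    found = []
    for path in folder.iterdir():
        match = CKPT_RE.match(path.name)
        if match:
            # Sort numerically by (epoch, batch) so ep10 comes after ep9
            found.append(((int(match[1]), int(match[2])), path))
    found.sort()
    for _, stale in found[:-keep]:
        stale.unlink()  # oldest checkpoints are removed first
    return [path.name for _, path in found[-keep:]]


with tempfile.TemporaryDirectory() as tmp:
    folder = Path(tmp)
    for epoch in range(1, 6):
        (folder / f"ep{epoch}-ba{epoch * 100}-rank0.pt").touch()
    print(rotate_checkpoints(folder, keep=3))
    # ['ep3-ba300-rank0.pt', 'ep4-ba400-rank0.pt', 'ep5-ba500-rank0.pt']
```

Sorting on parsed numbers rather than file modification times keeps the order deterministic even when several checkpoints land within the same second.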

Launching Multi-GPU Training

Composer includes its own launcher. Save your training script as train.py and run:

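# 4 GPUs on a single node
composer -n 4 train.py

# 8 GPUs across 2 nodes (-n is processes per node): run on node 0...
composer -n 4 --world_size 8 --node_rank 0 --master_addr 10.0.0.1 --master_port 7501 train.py
# ...and on node 1
composer -n 4 --world_size 8 --node_rank 1 --master_addr 10.0.0.1 --master_port 7501 train.py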

Alternatively, use torchrun if you prefer the standard PyTorch launcher:

torchrun --nproc_per_node=4 train.py

Both work. Composer’s launcher sets some environment variables that its Trainer picks up automatically, so it’s slightly less config to deal with.
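Both launchers describe the process layout through environment variables. A minimal sketch that reads three variables both torchrun and Composer's launcher set (RANK, LOCAL_RANK, WORLD_SIZE); ProcessLayout and read_layout are hypothetical names, and the single-process defaults are an assumption for running the script without a launcher. In real training code, composer.utils.dist (for example dist.get_global_rank()) is the sturdier way to get these values.

```python
import os
from dataclasses import dataclass
from typing import Mapping


@dataclass(frozen=True)
class ProcessLayout:
    rank: int        # global rank across all nodes
    local_rank: int  # GPU index on this node
    world_size: int  # total processes across all nodes


def read_layout(env: Mapping[str, str] = os.environ) -> ProcessLayout:
    """Read launcher-provided variables, defaulting to a single process."""
    return ProcessLayout(
        rank=int(env.get("RANK", "0")),
        local_rank=int(env.get("LOCAL_RANK", "0")),
        world_size=int(env.get("WORLD_SIZE", "1")),
    )


layout = read_layout()
if layout.rank == 0:
    # Log once instead of once per GPU process
    print(f"{layout.world_size} process(es) in this run")
```

Guarding prints and file writes with rank == 0 is the usual way to avoid every GPU process repeating the same output.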

For multi-node runs, make sure:

  • All nodes can reach the master address on the specified port.
  • NCCL is using the right network interface. Set NCCL_SOCKET_IFNAME if you have multiple NICs.
  • The same code and data paths exist on every node.
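The first checklist item can be verified before launching. A minimal stdlib sketch that tries a TCP connection to the master address and port; can_reach is a hypothetical helper. Run it from each worker node while something is listening on that port on the master (for example a temporary `python -m http.server 7501`), because a closed port looks the same as a firewall block here.

```python
import socket


def can_reach(host: str, port: int, timeout: float = 3.0) -> bool:
    """Return True if a TCP connection to host:port succeeds within `timeout` seconds."""
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:  # refused, unreachable, or timed out (socket.timeout is an OSError)
        return False
```

On each worker node, can_reach("10.0.0.1", 7501) should return True before you start the multi-node launch.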

Profiling Training Performance

Composer integrates with PyTorch’s profiler through a callback:

from composer.callbacks import RuntimeEstimator
from composer.profiler import Profiler, JSONTraceHandler, cyclic_schedule

# Estimates the time remaining in the run
runtime_estimator = RuntimeEstimator()

# PyTorch profiler via Composer's profiler interface

profiler = Profiler(
    trace_handlers=[JSONTraceHandler(folder="./profiler_traces")],
    schedule=cyclic_schedule(
        skip_first=1,
        wait=0,
        warmup=1,
        active=4,
        repeat=1,
    ),
    torch_prof_memory_filename="memory_timeline.html",
)

trainer = Trainer(
    model=composer_model,
    train_dataloader=train_loader,
    optimizers=optimizer,
    schedulers=scheduler,
    max_duration="1ep",
    device_train_microbatch_size=4,
    precision="amp_bf16",
    fsdp_config=fsdp_config,
    callbacks=[runtime_estimator],
    profiler=profiler,
    seed=42,
)

trainer.fit()

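After training, open the JSON traces in chrome://tracing or Perfetto; the PyTorch profiler traces Composer writes alongside them can also be loaded with TensorBoard's PyTorch profiler plugin. Look for: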

  • Large gaps between GPU kernels – indicates CPU bottleneck or slow data loading.
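  • Collectives (all-gather and reduce-scatter under FSDP, all-reduce under NO_SHARD) taking more than about 20% of step time – your communication overhead is too high. Try SHARD_GRAD_OP instead of FULL_SHARD, or increase the microbatch size so each collective is amortized over more compute.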
  • Memory spikes during forward pass – activation checkpointing isn’t covering all layers, or your batch is too large.
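The communication check in the list above can be given a rough number. This heuristic sketch reads a PyTorch profiler trace in Chrome trace format (a traceEvents list whose complete events have ph "X", a name, and a dur in microseconds) and reports the share of summed GPU-kernel time spent in NCCL kernels. Filtering on a cat of "kernel" and on "nccl" in the kernel name are assumptions about how the profiler labels events, so check them against your own trace. This measures kernel time rather than wall-clock step time, and communication overlapped with compute counts fully.

```python
import json


def nccl_kernel_share(trace_path: str) -> float:
    """Heuristic: fraction of summed GPU-kernel time spent in NCCL kernels."""
    with open(trace_path) as f:
        trace = json.load(f)
    # Chrome traces are either {"traceEvents": [...]} or a bare list of events
    events = trace["traceEvents"] if isinstance(trace, dict) else trace
    kernel_us = nccl_us = 0.0
    for event in events:
        # Complete events (ph == "X") carry a duration; keep only GPU kernels
        if event.get("ph") != "X" or str(event.get("cat", "")).lower() != "kernel":
            continue
        duration = event.get("dur", 0)
        kernel_us += duration
        if "nccl" in event.get("name", "").lower():
            nccl_us += duration
    return nccl_us / kernel_us if kernel_us else 0.0
```

A share that stays well above 0.2 on FULL_SHARD points at the same fix as the list above: fewer or larger collectives.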

Complete Training Script

Here’s a full script that ties everything together – FSDP sharding, speed-up algorithms, callbacks, checkpointing, and logging:

#!/usr/bin/env python3
"""Train BERT for sequence classification with Composer + FSDP."""
import torch
from composer import Trainer
from composer.algorithms import GradientClipping, FusedLayerNorm
from composer.callbacks import SpeedMonitor, LRMonitor, MemoryMonitor
from composer.loggers import FileLogger
from composer.models import HuggingFaceModel
from composer.optim import DecoupledAdamW, LinearWithWarmupScheduler
from composer.utils import dist
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoModelForSequenceClassification, AutoTokenizer


def main():
    # Model
    model_name = "bert-base-uncased"
    hf_model = AutoModelForSequenceClassification.from_pretrained(
        model_name, num_labels=2
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    composer_model = HuggingFaceModel(hf_model, tokenizer=tokenizer, use_logits=True)

    # Data
    dataset = load_dataset("glue", "sst2", split="train")

    def tokenize_fn(examples):
        return tokenizer(
            examples["sentence"], truncation=True, padding="max_length", max_length=128
        )

    dataset = dataset.map(tokenize_fn, batched=True, remove_columns=["sentence", "idx"])
    dataset = dataset.rename_column("label", "labels")  # HF models expect `labels`
    dataset.set_format("torch", columns=["input_ids", "attention_mask", "labels"])
    # batch_size is per GPU; the distributed sampler gives each rank its own shard
    train_loader = DataLoader(
        dataset,
        batch_size=32,
        sampler=dist.get_sampler(dataset, shuffle=True),
        num_workers=4,
    )

    # Optimizer and scheduler
    optimizer = DecoupledAdamW(composer_model.parameters(), lr=2e-5, weight_decay=0.01)
    scheduler = LinearWithWarmupScheduler(t_warmup="0.06dur", alpha_f=0.0)

    # FSDP
    fsdp_config = {
        "sharding_strategy": "FULL_SHARD",
        "cpu_offload": False,
        "mixed_precision": "PURE",
        "backward_prefetch": "BACKWARD_PRE",
        "activation_checkpointing": True,
        "activation_cpu_offload": False,
        "limit_all_gathers": True,
        "verbose": False,
    }

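    # Algorithms (FusedLayerNorm needs NVIDIA Apex installed)
    algorithms = [
        GradientClipping(clipping_type="norm", clipping_threshold=1.0),
        FusedLayerNorm(),
    ]

    # Callbacks and loggers ({rank} gives each GPU process its own log file)
    file_logger = FileLogger(filename="training_log-rank{rank}.txt", flush_interval=50)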

    trainer = Trainer(
        model=composer_model,
        train_dataloader=train_loader,
        optimizers=optimizer,
        schedulers=scheduler,
        max_duration="3ep",
        device_train_microbatch_size=8,
        precision="amp_bf16",
        fsdp_config=fsdp_config,
        algorithms=algorithms,
        loggers=[file_logger],
        callbacks=[SpeedMonitor(window_size=50), LRMonitor(), MemoryMonitor()],
        save_folder="./checkpoints",
        save_interval="1ep",
        save_num_checkpoints_to_keep=2,
        seed=42,
        progress_bar=True,
    )

    trainer.fit()
    print("Training complete. Checkpoints saved to ./checkpoints/")


if __name__ == "__main__":
    main()

Launch it:

composer -n 4 train.py

Common Errors and Fixes

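RuntimeError: ShardedTensor metadata mismatch – happens when resuming from a sharded checkpoint saved with a different number of GPUs than you’re training with now. Whether Composer can reshard such a checkpoint on load depends on your Composer and PyTorch versions, so check the checkpointing docs for your release. Otherwise, resume on the original GPU count, or save full (unsharded) checkpoints, which aren't tied to a particular world size. Passing load_weights_only=True to the Trainer skips the optimizer state, which only helps if you can afford to restart the optimizer.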

CUDA out of memory even with FULL_SHARD – lower device_train_microbatch_size first, or set it to "auto" so Composer shrinks the microbatch automatically when it hits OOM. If it’s already 1, enable both activation_checkpointing and cpu_offload in the FSDP config. Also check that no other processes are using GPU memory (nvidia-smi).

ValueError: Cannot find batch size from the provided dataloader – Composer tries to auto-detect the batch size. If your DataLoader uses a custom collator or a non-standard batch type, set device_train_microbatch_size explicitly and make sure each batch is a dictionary with standard keys (input_ids, attention_mask, labels), or wrap the loader in composer.core.DataSpec and supply a get_num_samples_in_batch function.

Timeout at barrier during multi-node training – one node can’t reach the master. Check firewall rules on the --master_port. Set NCCL_SOCKET_IFNAME=eth0 (or your network interface) and NCCL_DEBUG=INFO to see where communication stalls.

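ImportError: cannot import name 'FusedLayerNorm' – your Composer release no longer ships this algorithm (newer releases deprecate it in favor of LowPrecisionLayerNorm), so drop it or switch. On releases that do ship it, FusedLayerNorm is built on NVIDIA Apex, not triton: it raises an error if Apex is missing rather than silently falling back. Without it, training uses the standard PyTorch LayerNorm, which is slower but functionally equivalent.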

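Checkpoints are huge – Composer checkpoints include the optimizer state as well as the weights. With Adam-style optimizers that is two fp32 tensors per parameter on top of the fp32 weights, so a checkpoint is roughly three times the size of the weights alone. Use save_weights_only=True if you only need model weights for inference. For training resumption, you need the full checkpoint.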

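Training hangs after first epoch – often a data loader issue, such as a worker stuck on a corrupted sample. Set num_workers=0 to isolate whether the hang is data-related. Also confirm every rank sees the same number of batches (for example, by using composer.utils.dist.get_sampler): a rank that runs out of data early leaves the others waiting at a collective.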
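If the hang looks data-related, a single-process scan can pinpoint the culprit. A minimal sketch that walks a map-style dataset (anything with __len__ and __getitem__, which includes HuggingFace datasets) in the main process and records samples that raise or load slowly; scan_samples and the one-second threshold are hypothetical choices, not Composer features.

```python
import time
from typing import Any, List, Sequence, Tuple


def scan_samples(dataset: Sequence[Any], slow_seconds: float = 1.0) -> Tuple[List[int], List[Tuple[int, str]]]:
    """Load every sample in-process; return (slow indices, [(index, error repr)])."""
    slow, broken = [], []
    for index in range(len(dataset)):
        start = time.perf_counter()
        try:
            dataset[index]
        except Exception as exc:  # record the failure and keep scanning
            broken.append((index, repr(exc)))
            continue
        if time.perf_counter() - start > slow_seconds:
            slow.append(index)
    return slow, broken
```

Run it as `slow, broken = scan_samples(dataset)` on the tokenized dataset before building the DataLoader; any index in broken is a candidate for the hang.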