MosaicML Composer wraps PyTorch’s Fully Sharded Data Parallel (FSDP) with a clean Trainer interface and pluggable speed-up algorithms. You get gradient accumulation, mixed precision, checkpoint management, and logging without wiring any of it yourself. The payoff: you write a training config, point it at a HuggingFace model, and Composer handles the distributed orchestration.
Install the stack:
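A reasonable starting point, assuming you want Composer plus the HuggingFace pieces used in the examples below (match your torch build to your CUDA version):

```bash
pip install mosaicml transformers datasets
```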
Verify FSDP support is available (requires PyTorch 2.0+):
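A quick sanity check, as a minimal sketch:

```python
import torch
import composer

print("torch:", torch.__version__)        # this guide assumes 2.0+
print("composer:", composer.__version__)

# This import only succeeds if the installed PyTorch build ships FSDP.
from torch.distributed.fsdp import FullyShardedDataParallel  # noqa: F401
print("FSDP is available")
```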
Setting Up a Composer Trainer with FSDP
Composer’s Trainer takes an fsdp_config dictionary that maps directly to PyTorch FSDP options. Here’s a minimal setup that shards a BERT model across all available GPUs:
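A sketch of that setup, assuming a HuggingFace BERT fine-tune; train_dataset, the model name, and the hyperparameters are placeholders:

```python
from torch.utils.data import DataLoader
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from composer import Trainer
from composer.models import HuggingFaceModel

# Wrap a HuggingFace model in Composer's HuggingFaceModel adapter.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
hf_model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
model = HuggingFaceModel(hf_model, tokenizer=tokenizer)

# train_dataset is a placeholder for your tokenized dataset.
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)

fsdp_config = {
    "sharding_strategy": "FULL_SHARD",   # shard params, grads, and optimizer state
    "activation_checkpointing": True,    # recompute activations in the backward pass
}

trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    max_duration="3ep",
    precision="amp_bf16",                # bf16 automatic mixed precision
    device_train_microbatch_size=16,     # per-GPU micro batch; grad accumulation is automatic
    fsdp_config=fsdp_config,
)
trainer.fit()
```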
Key things happening here:
- sharding_strategy: FULL_SHARD partitions parameters, gradients, and optimizer states across GPUs. This is equivalent to DeepSpeed ZeRO-3.
- device_train_microbatch_size controls the per-GPU micro batch. Composer handles gradient accumulation automatically if this is smaller than the DataLoader batch size.
- precision: amp_bf16 uses bf16 automatic mixed precision. On Ampere GPUs or newer, this is faster and more stable than fp16.
- activation_checkpointing: True recomputes activations during backward to save memory, at the cost of roughly 30% more compute.
FSDP Sharding Strategies
Composer exposes three sharding strategies through the fsdp_config:
FULL_SHARD
Parameters, gradients, and optimizer states are all sharded. Maximum memory savings, highest communication overhead. Use this when your model doesn’t fit on a single GPU even with mixed precision.
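A minimal config for this mode might look like the following sketch:

```python
fsdp_config = {
    "sharding_strategy": "FULL_SHARD",
    "activation_checkpointing": True,   # often paired with FULL_SHARD when memory is tight
}
```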
SHARD_GRAD_OP
Only gradients and optimizer states are sharded. Parameters stay replicated. Less communication than FULL_SHARD, but uses more memory. This is equivalent to DeepSpeed ZeRO-2 and works well when the model fits on each GPU but optimizer memory is the bottleneck.
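The corresponding config, as a sketch:

```python
fsdp_config = {
    "sharding_strategy": "SHARD_GRAD_OP",   # params replicated; grads and optimizer state sharded
}
```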
NO_SHARD
No sharding – standard DDP behavior. Useful for debugging or when you have plenty of GPU memory and want maximum throughput without any FSDP overhead.
Pick your strategy based on this rule of thumb: start with SHARD_GRAD_OP. If you OOM, switch to FULL_SHARD. If you still OOM, add cpu_offload: True and activation_checkpointing: True.
Speed-Up Algorithms
Composer ships with a library of algorithms that speed up training or reduce memory usage. You plug them into the Trainer as a list:
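A sketch of wiring algorithms into the Trainer; the argument values here are illustrative, not tuned:

```python
from composer import Trainer
from composer.algorithms import FusedLayerNorm, GradientClipping, LayerFreezing

algorithms = [
    GradientClipping(clipping_type="norm", clipping_threshold=1.0),  # clip the global gradient norm
    FusedLayerNorm(),                                                # fused LayerNorm kernels
    LayerFreezing(freeze_start=0.5, freeze_level=1.0),               # freeze layers in the back half of training
]

trainer = Trainer(
    model=model,                       # the ComposerModel from the earlier setup
    train_dataloader=train_dataloader,
    max_duration="3ep",
    algorithms=algorithms,
)
```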
Notable algorithms worth trying:
- GradientClipping – prevents exploding gradients. Essential for large models.
- FusedLayerNorm – replaces standard LayerNorm with a fused CUDA kernel. Free speed on NVIDIA GPUs.
- LayerFreezing – progressively freezes early layers during training. Reduces compute as training progresses.
Callbacks for Logging, Checkpointing, and Early Stopping
Composer’s callback system hooks into every stage of the training loop. Here’s how to set up the three you’ll always want:
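A sketch that wires up those callbacks along with checkpoint rotation; the "MulticlassAccuracy" metric name is a placeholder and must match a metric your model actually reports on the eval dataloader:

```python
from composer import Trainer
from composer.callbacks import EarlyStopper, LRMonitor, MemoryMonitor, SpeedMonitor

callbacks = [
    SpeedMonitor(window_size=100),   # samples/sec, tokens/sec, FLOPS
    LRMonitor(),                     # learning-rate schedule
    MemoryMonitor(),                 # GPU memory usage
    EarlyStopper(monitor="MulticlassAccuracy", dataloader_label="eval", patience=2),
]

trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    eval_dataloader=eval_dataloader,       # required for EarlyStopper's monitored metric
    max_duration="10ep",
    callbacks=callbacks,
    save_folder="./checkpoints",
    save_interval="1ep",
    save_num_checkpoints_to_keep=3,
)
```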
What each callback does:
- SpeedMonitor logs throughput (samples/sec, tokens/sec, FLOPS) – critical for spotting regressions.
- LRMonitor logs the learning rate schedule so you can verify warmup and decay are working.
- MemoryMonitor tracks GPU memory usage. Helps you decide if you can increase micro batch size.
- EarlyStopper halts training if the monitored metric doesn’t improve for patience epochs.
Checkpointing is built into the Trainer itself. The save_folder, save_interval, and save_num_checkpoints_to_keep parameters handle rotation automatically. Composer saves FSDP-aware checkpoints that you can resume from without manually gathering sharded weights.
Launching Multi-GPU Training
Composer includes its own launcher. Save your training script as train.py and run:
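For example, on a single node with 8 GPUs (adjust -n to your GPU count):

```bash
composer -n 8 train.py
```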
Alternatively, use torchrun if you prefer the standard PyTorch launcher:
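The single-node equivalent with torchrun:

```bash
torchrun --nproc_per_node=8 train.py
```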
Both work. Composer’s launcher sets some environment variables that its Trainer picks up automatically, so it’s slightly less config to deal with.
For multi-node runs, make sure:
- All nodes can reach the master address on the specified port.
- NCCL is using the right network interface. Set NCCL_SOCKET_IFNAME if you have multiple NICs (launch example below).
- The same code and data paths exist on every node.
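As a sketch, a two-node launch with torchrun might look like this on node 0; the interface name and addresses are placeholders:

```bash
# Run on node 0; repeat on node 1 with --node_rank=1.
export NCCL_SOCKET_IFNAME=eth0   # the NIC that reaches the other node
export NCCL_DEBUG=INFO           # verbose NCCL logs while debugging connectivity

torchrun \
  --nnodes=2 \
  --node_rank=0 \
  --nproc_per_node=8 \
  --master_addr=10.0.0.1 \
  --master_port=29500 \
  train.py
```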
Profiling Training Performance
Composer integrates with PyTorch’s profiler through a callback:
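A sketch using the composer.profiler module; recent Composer versions configure this through the Trainer's profiler argument, so check the names here (Profiler, JSONTraceHandler, cyclic_schedule) against your installed version:

```python
from composer import Trainer
from composer.profiler import JSONTraceHandler, Profiler, cyclic_schedule

profiler = Profiler(
    schedule=cyclic_schedule(wait=0, warmup=1, active=4, repeat=1),  # profile a few steps per cycle
    trace_handlers=[JSONTraceHandler(folder="profiler_traces")],     # Chrome-trace JSON output
)

trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    max_duration="1ep",
    profiler=profiler,
)
```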
After training, view traces in Chrome’s chrome://tracing or TensorBoard. Look for:
- Large gaps between GPU kernels – indicates CPU bottleneck or slow data loading.
- All-reduce taking more than 20% of step time – your communication overhead is too high. Try SHARD_GRAD_OP instead of FULL_SHARD, or increase the micro batch size.
- Memory spikes during forward pass – activation checkpointing isn’t covering all layers, or your batch is too large.
Complete Training Script
Here’s a full script that ties everything together – FSDP sharding, speed-up algorithms, callbacks, checkpointing, and logging:
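A sketch of such a script, assuming a BERT fine-tune on the IMDB dataset; the model name, dataset, and hyperparameters are placeholders you would swap for your own:

```python
# train.py
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from composer import Trainer
from composer.algorithms import FusedLayerNorm, GradientClipping
from composer.callbacks import LRMonitor, MemoryMonitor, SpeedMonitor
from composer.loggers import TensorboardLogger
from composer.models import HuggingFaceModel
from composer.optim import DecoupledAdamW, LinearWithWarmupScheduler

MODEL_NAME = "bert-base-uncased"

# --- Data: tokenize IMDB into fixed-length sequences ---
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

def tokenize(batch):
    return tokenizer(batch["text"], truncation=True, padding="max_length", max_length=256)

train_ds = load_dataset("imdb", split="train").map(tokenize, batched=True)
train_ds = train_ds.rename_column("label", "labels")
train_ds.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
train_dataloader = DataLoader(train_ds, batch_size=64, shuffle=True, num_workers=4)

# --- Model: HuggingFace BERT wrapped for Composer ---
hf_model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=2)
model = HuggingFaceModel(hf_model, tokenizer=tokenizer)

# --- Optimizer and LR schedule ---
optimizer = DecoupledAdamW(model.parameters(), lr=2e-5, weight_decay=0.01)
scheduler = LinearWithWarmupScheduler(t_warmup="0.1dur")  # 10% of training as warmup

# --- FSDP sharding ---
fsdp_config = {
    "sharding_strategy": "FULL_SHARD",
    "activation_checkpointing": True,
}

trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    optimizers=optimizer,
    schedulers=scheduler,
    max_duration="3ep",
    precision="amp_bf16",
    device_train_microbatch_size=16,
    fsdp_config=fsdp_config,
    algorithms=[
        GradientClipping(clipping_type="norm", clipping_threshold=1.0),
        FusedLayerNorm(),  # needs its fused-kernel dependency installed (see Common Errors below)
    ],
    callbacks=[SpeedMonitor(window_size=100), LRMonitor(), MemoryMonitor()],
    loggers=[TensorboardLogger()],  # TensorBoard metrics
    save_folder="./checkpoints",
    save_interval="1ep",
    save_num_checkpoints_to_keep=3,
)

if __name__ == "__main__":
    trainer.fit()
```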
Launch it:
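With the Composer launcher on a single 8-GPU node, for example:

```bash
composer -n 8 train.py
```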
Common Errors and Fixes
RuntimeError: ShardedTensor metadata mismatch – happens when resuming from a checkpoint saved with a different number of GPUs than you’re training with now. Composer supports elastic checkpointing, but you need to pass load_weights_only=True in the Trainer or use composer.utils.dist_checkpoint utilities to reshape shards.
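A hedged sketch of the weights-only resume; the checkpoint path is a placeholder:

```python
trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    max_duration="3ep",
    fsdp_config={"sharding_strategy": "FULL_SHARD"},
    load_path="./checkpoints/ep2-ba1000-rank0.pt",  # placeholder path
    load_weights_only=True,                         # skip optimizer state saved under the old world size
)
```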
CUDA out of memory even with FULL_SHARD – lower device_train_microbatch_size first. If that’s already 1, enable both activation_checkpointing and cpu_offload in the FSDP config. Also check that no other processes are using GPU memory (nvidia-smi).
ValueError: Cannot find batch size from the provided dataloader – Composer tries to auto-detect the batch size. If your DataLoader uses a custom collator or a non-standard batch key, set device_train_microbatch_size explicitly and make sure your batch dictionary has standard keys (input_ids, attention_mask, labels).
Timeout at barrier during multi-node training – one node can’t reach the master. Check firewall rules on the --master_port. Set NCCL_SOCKET_IFNAME=eth0 (or your network interface) and NCCL_DEBUG=INFO to see where communication stalls.
ImportError: cannot import name 'FusedLayerNorm' – you need triton installed. Run pip install triton. Without it, Composer falls back to the standard PyTorch LayerNorm, which is slower but functionally identical.
Checkpoints are huge – FSDP full-state checkpoints include optimizer state for every shard. Use save_weights_only=True if you only need model weights for inference. For training resumption, you need the full checkpoint.
Training hangs after first epoch – often a data loader issue. Make sure num_workers > 0 and that your dataset doesn’t have corrupted samples. Try num_workers=0 to isolate whether the hang is data-related.
Related Guides
- How to Build a Model Training Pipeline with Lightning Fabric
- How to Set Up Distributed Training with DeepSpeed and ZeRO
- How to Build a Model Training Dashboard with TensorBoard and Prometheus
- How to Build a Model Training Checkpoint Pipeline with PyTorch
- How to Build a Model Artifact Signing and Verification Pipeline
- How to Build a Model Training Queue with Redis and Worker Pools
- How to Build a Model Artifact Pipeline with ORAS and Container Registries
- How to Build a Multi-Node Training Pipeline with Fabric and NCCL
- How to Build a Model Serving Cluster with Ray Serve and Docker
- How to Set Up Multi-GPU Training with PyTorch