The Quick Version
Mixed precision training runs most operations in half precision (float16 or bfloat16) instead of full precision (float32). This roughly halves activation memory and can double throughput on GPUs with tensor cores (V100 and newer). PyTorch’s torch.amp makes it a three-line change.
That’s it. Wrap your forward pass in autocast, use GradScaler for the backward pass, and you typically get 30-50% faster training at roughly half the activation memory. The model weights stay in float32; only the forward and backward computations use float16.
How Mixed Precision Works
Three components work together:
autocast automatically picks float16 for ops that benefit from it (matmuls, convolutions) and keeps float32 for ops that need precision (softmax, layer norm, loss functions).
GradScaler multiplies the loss by a large number before the backward pass. This prevents tiny float16 gradients from underflowing to zero. It then unscales the gradients before the optimizer step.
Master weights stay in float32. The optimizer updates float32 weights, then the next forward pass casts them to float16 on-the-fly. This preserves training stability.
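You can observe the autocast and master-weight behavior directly (a small sketch; on CPU, autocast falls back to bfloat16):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
# CUDA autocast defaults to float16; CPU autocast supports bfloat16.
dtype = torch.float16 if device == "cuda" else torch.bfloat16

linear = torch.nn.Linear(16, 16).to(device)
x = torch.randn(8, 16, device=device)

with torch.autocast(device_type=device, dtype=dtype):
    y = linear(x)          # matmul-backed op: autocast runs it in low precision

print(linear.weight.dtype)  # torch.float32 -- master weights are untouched
print(y.dtype)              # low-precision activation (float16 or bfloat16)
```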
bfloat16 vs float16
Ampere GPUs (A100, RTX 3090) and newer support bfloat16, which has the same dynamic range as float32 but reduced precision. The advantage: no GradScaler needed.
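With bfloat16 the training loop simplifies to plain autocast, no scaler at all (a sketch with a placeholder model; CPU autocast also supports bfloat16, so this runs without a GPU):

```python
import torch
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(64, 10).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

inputs = torch.randn(32, 64, device=device)
targets = torch.randint(0, 10, (32,), device=device)

optimizer.zero_grad(set_to_none=True)
with torch.autocast(device_type=device, dtype=torch.bfloat16):
    loss = F.cross_entropy(model(inputs), targets)
loss.backward()      # no GradScaler: bfloat16 has float32's dynamic range
optimizer.step()
```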
| Dtype | Range | Precision | GPU Support | GradScaler |
|---|---|---|---|---|
| float32 | Full | Full | All | Not needed |
| float16 | Limited | High | All (tensor cores on V100+) | Required |
| bfloat16 | Full | Reduced | A100+ (Ampere) | Not needed |
Use bfloat16 if your GPU supports it. It’s simpler and avoids the rare numerical issues that float16 can cause with very large or small values. Check support at runtime with torch.cuda.is_bf16_supported(), guarded by torch.cuda.is_available(), since the check assumes a CUDA device is present.
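A common pattern (not an official API, just a convention) is to pick the training dtype once at startup; the short-circuit matters because torch.cuda.is_bf16_supported() errors without a CUDA device:

```python
import torch

if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
    amp_dtype = torch.bfloat16   # Ampere or newer: prefer bfloat16
elif torch.cuda.is_available():
    amp_dtype = torch.float16    # older tensor-core GPUs: float16 + GradScaler
else:
    amp_dtype = torch.bfloat16   # CPU autocast supports bfloat16

print(amp_dtype)
```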
Mixed Precision for Transformer Training
Transformers benefit enormously from mixed precision because attention is dominated by matrix multiplications — exactly the ops that tensor cores accelerate.
Note the gradient clipping: call scaler.unscale_(optimizer) before clip_grad_norm_ so you’re clipping the actual gradient values, not the scaled ones.
Memory Savings and Benchmarks
Mixed precision roughly halves activation memory because intermediate tensors are stored in float16 instead of float32. This lets you double your batch size, which further improves GPU utilization.
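You can reproduce this kind of comparison yourself with torch.cuda.max_memory_allocated (a sketch using a hypothetical MLP; the helper returns 0.0 on a machine without CUDA):

```python
import torch

def peak_mem_mb(use_amp: bool, batch_size: int = 64) -> float:
    """Peak CUDA memory in MB for one forward/backward pass (0.0 without a GPU)."""
    if not torch.cuda.is_available():
        return 0.0
    model = torch.nn.Sequential(
        torch.nn.Linear(1024, 4096), torch.nn.ReLU(), torch.nn.Linear(4096, 10)
    ).cuda()
    x = torch.randn(batch_size, 1024, device="cuda")
    torch.cuda.reset_peak_memory_stats()
    with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=use_amp):
        loss = model(x).sum()
    loss.backward()
    return torch.cuda.max_memory_allocated() / 1024**2

print(f"fp32 peak: {peak_mem_mb(False):.1f} MB")
print(f"amp  peak: {peak_mem_mb(True):.1f} MB")
```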
Typical results on RTX 3090:
- ResNet-50, bs=64, FP32: ~8.2 GB
- ResNet-50, bs=64, AMP: ~4.8 GB (42% reduction)
- ResNet-50, bs=128, AMP: ~8.5 GB (2x batch, same memory)
Common Errors and Fixes
Loss becomes NaN after a few steps
Float16 overflow. The GradScaler should handle this automatically — it detects NaN gradients and skips that optimizer step. If it persists, the issue is likely in your model, not AMP. Check for very large values in your loss computation.
Training is slower with AMP, not faster
Your GPU doesn’t have tensor cores (GTX 1080, etc.), or your model is too small to benefit. Tensor cores need specific matrix dimensions (multiples of 8) to activate. Ensure batch size and hidden dimensions are multiples of 8.
Gradient clipping doesn’t work correctly
You must call scaler.unscale_(optimizer) before clip_grad_norm_. Without unscaling, you’re clipping the scaled gradients, which either clips nothing (scale too low) or clips everything (scale too high).
Model validation metrics are slightly different with AMP
Some ops produce slightly different results in float16. This is expected; differences should be under roughly 0.1% on final metrics. If you see larger gaps, specific layers might need float32, which you can force by disabling autocast locally (torch.autocast(device_type="cuda", enabled=False)) around them.
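One way to do that is a nested autocast context with enabled=False (a sketch with a hypothetical `sensitive` layer standing in for whichever layer needs full precision):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
amp_dtype = torch.float16 if device == "cuda" else torch.bfloat16
x = torch.randn(8, 32, device=device)
sensitive = torch.nn.Linear(32, 32).to(device)  # hypothetical precision-sensitive layer

with torch.autocast(device_type=device, dtype=amp_dtype):
    # ... low-precision ops run here as usual ...
    with torch.autocast(device_type=device, enabled=False):
        out = sensitive(x.float())   # this region runs in full float32
print(out.dtype)  # torch.float32
```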
Hugging Face Trainer already handles AMP
If you’re using the Transformers Trainer, just pass fp16=True or bf16=True in TrainingArguments. Don’t add manual AMP on top — the Trainer handles scaler, autocast, and gradient accumulation internally.
Related Guides
- How to Set Up Multi-GPU Training with PyTorch
- How to Build a Model Training Checkpoint Pipeline with PyTorch
- How to Scale ML Training and Inference with Ray
- How to Profile and Optimize GPU Memory for LLM Training
- How to Compile and Optimize PyTorch Models with torch.compile
- How to Build a Multi-Node Training Pipeline with Fabric and NCCL
- How to Build a Model Training Pipeline with Lightning Fabric
- How to Set Up Distributed Training with DeepSpeed and ZeRO
- How to Monitor GPU Utilization and Debug Training Bottlenecks
- How to Build a Model Training Pipeline with Composer and FSDP