Start by Measuring
Before you optimize anything, you need to know where your VRAM is going. PyTorch gives you the tools to answer that question precisely. Here’s a profiling script you can drop into any training run.
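One way to sketch that script, assuming an HF-style model whose forward returns an object with a `.loss` attribute (`model` and `batch` are placeholders you supply):

```python
import torch

def gib(nbytes: int) -> float:
    """Convert bytes to GiB for readable printouts."""
    return nbytes / 2**30

def profile_step(model, batch):
    """Print how much VRAM weights, activations, and gradients each consume."""
    torch.cuda.reset_peak_memory_stats()
    weights = torch.cuda.memory_allocated()   # params + buffers already resident

    loss = model(**batch).loss                # forward: activations accumulate
    after_fwd = torch.cuda.memory_allocated()

    loss.backward()                           # backward: activations freed, grads allocated
    after_bwd = torch.cuda.memory_allocated()

    print(f"weights:     {gib(weights):6.2f} GiB")
    print(f"activations: {gib(after_fwd - weights):6.2f} GiB")
    print(f"gradients:   {gib(after_bwd - weights):6.2f} GiB")
    print(f"peak:        {gib(torch.cuda.max_memory_allocated()):6.2f} GiB")
```

Note that `memory_allocated()` only counts live tensors; `memory_reserved()` additionally includes the caching allocator's pool.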
This tells you three things: how much VRAM the model weights consume, how much activations add during forward, and how much gradients add during backward. For Llama 3.1 8B in bf16, expect roughly 16 GB for weights, plus another 8-12 GB for activations and gradients depending on sequence length.
Capture Memory Snapshots for Visual Debugging
memory_summary() gives you a snapshot in time. When you need to see every allocation over an entire training loop, use PyTorch’s memory snapshot tool. It records every malloc and free on the GPU and exports a file you can visualize.
Open https://pytorch.org/memory_viz in a browser, drag in the pickle file, and you get an interactive timeline of every GPU allocation. You can hover over blocks to see stack traces, spot fragmentation, and identify which operations consume the most VRAM. This is the single most useful debugging tool when you’re hitting OOM errors and don’t know why.
Where GPU Memory Actually Goes
For a typical LLM training run, memory breaks down like this:
| Component | Approximate Cost (bf16, 7B model) |
|---|---|
| Model weights | ~14 GB |
| Gradients | ~14 GB |
| Optimizer states (AdamW) | ~28 GB (2 states per param) |
| Activations | 4-16 GB (depends on batch/seq length) |
| Total | 60-72 GB |
That’s why a 7B model can’t train in fp32 on an 80 GB A100. AdamW is the biggest offender – it stores two fp32 state tensors (momentum and variance) per parameter. The optimizer alone eats twice the memory of the model weights.
Gradient Checkpointing
Gradient checkpointing trades compute for memory by discarding intermediate activations during forward and recomputing them during backward. This cuts activation memory from O(n) to O(sqrt(n)) for n layers, typically saving 40-60% of activation memory at the cost of ~25% slower training.
For HuggingFace models, it’s one line:
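That one line is the `gradient_checkpointing_enable` call; here it is wrapped in a small helper (the `use_cache` line is an extra precaution, since the generation KV cache is useless during training and conflicts with checkpointing):

```python
def enable_gradient_checkpointing(model):
    """Enable non-reentrant activation checkpointing on a HF transformers model."""
    model.gradient_checkpointing_enable(
        gradient_checkpointing_kwargs={"use_reentrant": False}
    )
    model.config.use_cache = False  # KV cache conflicts with checkpointing
```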
For custom models using raw PyTorch, wrap individual blocks with torch.utils.checkpoint.checkpoint:
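A minimal custom module wired this way (layer sizes are arbitrary):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class CheckpointedStack(nn.Module):
    """A toy stack of blocks whose activations are recomputed during backward."""
    def __init__(self, dim: int = 64, n_layers: int = 4):
        super().__init__()
        self.blocks = nn.ModuleList(
            nn.Sequential(nn.Linear(dim, dim), nn.GELU()) for _ in range(n_layers)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            # Intermediate activations inside `block` are discarded after
            # forward and recomputed when backward needs them.
            x = checkpoint(block, x, use_reentrant=False)
        return x
```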
Set use_reentrant=False. It is the variant the PyTorch documentation recommends, and it avoids subtle bugs with models that have unused parameters.
Mixed Precision Training
Training in bf16 instead of fp32 cuts weight and activation memory in half. bf16 is better than fp16 for LLMs because it has the same exponent range as fp32, which means fewer overflow issues and no need for loss scaling in most cases.
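With HuggingFace Trainer it is a single flag (`output_dir` is a placeholder):

```python
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="out",
    bf16=True,  # run forward/backward under bf16 autocast
)
```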
If you’re writing a custom training loop, use torch.autocast:
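A sketch of such a loop; computing the loss in fp32 is a common stability choice, and `device_type="cpu"` lets you dry-run it without a GPU:

```python
import torch

def train_step(model, x, y, optimizer, loss_fn, device_type="cuda"):
    """One mixed-precision training step."""
    optimizer.zero_grad(set_to_none=True)
    with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
        out = model(x)              # matmuls and convs run in bf16
    loss = loss_fn(out.float(), y)  # compute the loss in fp32 for stability
    loss.backward()                 # backward outside the autocast context
    optimizer.step()
    return loss.detach()
```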
On Ampere GPUs (A100, RTX 3090) and newer, bf16 is the right choice. Use fp16 only if you’re stuck on older Volta or Turing hardware.
8-bit Optimizers with bitsandbytes
AdamW’s optimizer states are often the largest memory consumer. Standard AdamW stores two fp32 tensors per parameter. For a 7B model, that’s 56 GB just for the optimizer. Switching to 8-bit AdamW from bitsandbytes cuts that to ~14 GB with almost no training quality loss.
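Install it first:

```shell
pip install bitsandbytes
```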
If you’re using HuggingFace Trainer, set it in TrainingArguments:
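For example:

```python
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="out",        # placeholder
    optim="adamw_bnb_8bit",  # Trainer constructs the 8-bit AdamW for you
)
```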
This is the highest-ROI optimization on this list. You lose essentially nothing in terms of convergence quality and save 75% of optimizer memory.
Gradient Accumulation
When you can’t fit a large batch into VRAM, gradient accumulation simulates it by running multiple smaller forward/backward passes before updating weights. Memory usage matches the small batch size, but the effective batch size is micro_batch * accumulation_steps.
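A sketch of the loop (the MSE loss and the `(x, y)` batches are stand-ins for your task):

```python
import torch

def train_epoch(model, loader, optimizer, accumulation_steps: int = 16):
    """Accumulate gradients over several micro-batches before each update."""
    optimizer.zero_grad(set_to_none=True)
    for step, (x, y) in enumerate(loader):
        loss = torch.nn.functional.mse_loss(model(x), y)
        # Scale so the summed gradients match one big-batch backward
        (loss / accumulation_steps).backward()
        if (step + 1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)
```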
The key detail people miss: divide the loss by accumulation_steps. Without this, each optimizer step sees gradients that are accumulation_steps times larger than intended, which effectively multiplies your learning rate and can make training diverge.
Putting It All Together: VRAM Budget for a 7B Model
Here’s what each optimization saves on a Llama-class 7B model:
| Configuration | Est. VRAM | Hardware |
|---|---|---|
| Full fp32, AdamW, batch=4 | ~120 GB | 2x A100 80GB |
| bf16 weights + fp32 optimizer | ~72 GB | 1x A100 80GB |
| bf16 + gradient checkpointing | ~60 GB | 1x A100 80GB |
| bf16 + checkpointing + AdamW 8-bit | ~40 GB | 1x A100 40GB |
| bf16 + checkpointing + AdamW 8-bit + batch=1 + accum=16 | ~24 GB | RTX 4090 / RTX 3090 |
The bottom line: stacking bf16, gradient checkpointing, 8-bit optimizer, and gradient accumulation takes you from needing two A100s to training on a single consumer GPU. The training will be slower, but it works.
Common Errors and Fixes
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 2.00 GiB
The most common error in LLM training. PyTorch is asking for a contiguous block of memory that doesn’t exist. First, check what’s actually allocated:
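For example:

```python
import torch

if torch.cuda.is_available():
    allocated = torch.cuda.memory_allocated() / 2**30  # live tensors
    reserved = torch.cuda.memory_reserved() / 2**30    # allocator's cached pool
    print(f"allocated: {allocated:.2f} GiB | reserved: {reserved:.2f} GiB")
    print(torch.cuda.memory_summary())                 # detailed breakdown
```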
If “reserved” is much larger than “allocated,” you have fragmentation. Set the environment variable before your script runs:
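For example (`train.py` stands in for your entry point):

```shell
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
python train.py
```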
This lets the allocator grow segments incrementally instead of requesting large contiguous blocks, which reduces fragmentation significantly on PyTorch 2.1+.
RuntimeError: Expected all tensors to be on the same device
Despite appearances, gradient checkpointing is rarely the culprit here; the usual cause is that your inputs and model weights ended up on different devices. With HuggingFace models using device_map="auto", some layers may be offloaded to CPU if VRAM is tight. Make sure your input tensors and model are on the same device, and force everything to GPU:
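A sketch of the pattern, with a toy model and batch dict standing in for yours:

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

model = torch.nn.Linear(8, 8).to(device)  # stand-in for your model
batch = {"x": torch.randn(2, 8)}          # stand-in for your tokenized batch

# Move every input tensor to the model's device before the forward pass
batch = {k: v.to(device) for k, v in batch.items()}
out = model(batch["x"])
```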
ValueError: fp16 mixed precision requires a GPU with compute capability >= 7.0
Your GPU predates hardware-accelerated fp16 (compute capability 7.0, introduced with Volta). Fall back to fp32 on these cards; bf16 is not a way out here, since it requires Ampere (compute capability 8.0) or newer. Check your compute capability:
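For example:

```python
import torch

if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    print(f"compute capability: {major}.{minor}")            # >= 7.0 for fast fp16
    print("bf16 supported:", torch.cuda.is_bf16_supported()) # Ampere and newer
```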
Loss is NaN after enabling mixed precision.
fp16 is more prone to this than bf16 due to its limited range. Switch to bf16 if your hardware supports it. If you must use fp16, make sure you’re using GradScaler:
bitsandbytes throws CUDA Setup failed on import.
The library can’t find your CUDA installation. Set the path explicitly:
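For example (the CUDA path shown is typical, not universal; adjust it to your install):

```shell
# Point the dynamic loader at your CUDA runtime libraries
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda/lib64
# If the wrong CUDA version is detected, override it explicitly (e.g. CUDA 12.1)
export BNB_CUDA_VERSION=121
```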
Verify with:
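```shell
python -m bitsandbytes   # runs the library's built-in environment diagnostic
```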
Training slows down dramatically after enabling gradient checkpointing.
A 20-30% slowdown is normal. More than 50% suggests you’re checkpointing too aggressively. With HuggingFace models, the checkpointing granularity is per transformer layer, which is the right balance. If you’re wrapping custom code, checkpoint entire blocks rather than individual operations.
Related Guides
- How to Set Up Multi-GPU Training with PyTorch
- How to Optimize LLM Serving with KV Cache and PagedAttention
- How to Set Up Distributed Training with DeepSpeed and ZeRO
- How to Speed Up Training with Mixed Precision and PyTorch AMP
- How to Build a Model Training Pipeline with Lightning Fabric
- How to Scale ML Training and Inference with Ray
- How to Monitor GPU Utilization and Debug Training Bottlenecks
- How to Build a Model Training Checkpoint Pipeline with PyTorch
- How to Build a Model Training Pipeline with Composer and FSDP
- How to Use PyTorch FlexAttention for Fast LLM Inference