Start by Measuring

Before you optimize anything, you need to know where your VRAM is going. PyTorch gives you the tools to answer that question precisely. Here’s a profiling script you can drop into any training run.

import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer

def print_memory_stats(label: str):
    """Print current GPU memory usage with a label."""
    allocated = torch.cuda.memory_allocated() / 1024**3
    reserved = torch.cuda.memory_reserved() / 1024**3
    peak = torch.cuda.max_memory_allocated() / 1024**3
    print(f"[{label}]")
    print(f"  Allocated: {allocated:.2f} GB")
    print(f"  Reserved:  {reserved:.2f} GB")
    print(f"  Peak:      {peak:.2f} GB")
    print()

torch.cuda.reset_peak_memory_stats()

# Load a model
model_id = "meta-llama/Llama-3.1-8B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="cuda"
)
print_memory_stats("After model load")

# Simulate a forward pass
inputs = tokenizer("The quick brown fox", return_tensors="pt").to("cuda")
with torch.no_grad():
    outputs = model(**inputs)
print_memory_stats("After forward pass (no grad)")

# Forward + backward
model.train()
inputs = tokenizer("The quick brown fox", return_tensors="pt").to("cuda")
outputs = model(**inputs, labels=inputs["input_ids"])
outputs.loss.backward()
print_memory_stats("After forward + backward")

# Full memory summary from PyTorch allocator
print(torch.cuda.memory_summary(abbreviated=True))

This tells you three things: how much VRAM the model weights consume, how much activations add during forward, and how much gradients add during backward. For Llama 3.1 8B in bf16, expect roughly 16 GB for weights, another ~16 GB for gradients after backward (bf16 gradients are the same size as the weights), plus activation memory that grows with batch size and sequence length.
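As a cross-check, the weight figure follows directly from parameter count times bytes per element. A minimal sketch (8.03e9 is the approximate Llama 3.1 8B parameter count; the helper name is ours):

```python
def weight_memory_gb(n_params: float, bytes_per_param: int = 2) -> float:
    """Estimate weight memory in GiB: parameter count x bytes per element."""
    return n_params * bytes_per_param / 1024**3

print(f"bf16: {weight_memory_gb(8.03e9):.1f} GB")      # ~15 GB, matching the profiler
print(f"fp32: {weight_memory_gb(8.03e9, 4):.1f} GB")   # double that in full precision
```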

Capture Memory Snapshots for Visual Debugging

memory_summary() gives you a snapshot in time. When you need to see every allocation over an entire training loop, use PyTorch’s memory snapshot tool. It records every malloc and free on the GPU and exports a file you can visualize.

import torch

# Start recording memory events
torch.cuda.memory._record_memory_history(max_entries=100_000)

# --- your training code here ---
model.train()
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

for step in range(5):
    inputs = tokenizer(
        "Sample training text for profiling purposes",
        return_tensors="pt", max_length=512, truncation=True, padding="max_length"
    ).to("cuda")
    outputs = model(**inputs, labels=inputs["input_ids"])
    outputs.loss.backward()
    optimizer.step()
    optimizer.zero_grad()

# Dump the snapshot
torch.cuda.memory._dump_snapshot("memory_snapshot.pickle")

# Stop recording
torch.cuda.memory._record_memory_history(enabled=None)

print("Snapshot saved to memory_snapshot.pickle")
print("Upload to https://pytorch.org/memory_viz to visualize")

Open https://pytorch.org/memory_viz in a browser, drag in the pickle file, and you get an interactive timeline of every GPU allocation. You can hover over blocks to see stack traces, spot fragmentation, and identify which operations consume the most VRAM. This is the single most useful debugging tool when you’re hitting OOM errors and don’t know why.

Where GPU Memory Actually Goes

For a typical LLM training run, memory breaks down like this:

| Component | Approximate Cost (bf16, 7B model) |
| --- | --- |
| Model weights | ~14 GB |
| Gradients | ~14 GB |
| Optimizer states (AdamW) | ~28 GB (2 bf16 states per param; ~56 GB if kept in fp32) |
| Activations | 4-16 GB (depends on batch/seq length) |
| Total | 60-72 GB |

That’s why a 7B model can’t train in fp32 on an 80 GB A100. AdamW is the biggest offender – it stores two fp32 state tensors (momentum and variance) per parameter. The optimizer alone eats twice the memory of the model weights.
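You can see the two-states-per-parameter behavior directly on CPU with a toy module (layer size here is arbitrary; this is a sketch, not a training setup):

```python
import torch
import torch.nn as nn

# Tiny CPU model to illustrate AdamW's per-parameter state
model = nn.Linear(256, 256)
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)

# Optimizer states are created lazily on the first step
model(torch.randn(4, 256)).sum().backward()
opt.step()

n_params = sum(p.numel() for p in model.parameters())
n_state = sum(
    t.numel()
    for s in opt.state.values()
    for t in s.values()
    if torch.is_tensor(t)
)
print(f"params: {n_params}, optimizer state elements: {n_state}")
# exp_avg + exp_avg_sq -> roughly 2 state elements per parameter
```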

Gradient Checkpointing

Gradient checkpointing trades compute for memory by discarding intermediate activations during forward and recomputing them during backward. This cuts activation memory from O(n) to O(sqrt(n)) for n layers, typically saving 40-60% of activation memory at the cost of ~25% slower training.
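Because activations are recomputed exactly, checkpointing changes memory and speed but not the math: gradients come out numerically identical. A CPU sketch with a hypothetical 4-layer MLP, using checkpoint_sequential to split the stack into segments:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

torch.manual_seed(0)
layers = nn.Sequential(
    *[nn.Sequential(nn.Linear(16, 16), nn.GELU()) for _ in range(4)]
)
x = torch.randn(2, 16, requires_grad=True)

# Plain forward/backward keeps all intermediate activations alive
layers(x).sum().backward()
grad_plain = x.grad.clone()

# Checkpointed version frees activations and recomputes them during backward
x.grad = None
out = checkpoint_sequential(layers, segments=2, input=x, use_reentrant=False)
out.sum().backward()
print(torch.allclose(grad_plain, x.grad))
```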

For HuggingFace models, it’s one line:

from transformers import AutoModelForCausalLM
import torch

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",
    torch_dtype=torch.bfloat16,
    device_map="cuda",
)

# Enable gradient checkpointing
model.gradient_checkpointing_enable()

# Verify it's active
print(f"Gradient checkpointing: {model.is_gradient_checkpointing}")
# Output: Gradient checkpointing: True

For custom models using raw PyTorch, wrap individual blocks with torch.utils.checkpoint.checkpoint:

import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class TransformerBlock(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.attn = nn.MultiheadAttention(hidden_dim, 8)
        self.ffn = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 4),
            nn.GELU(),
            nn.Linear(hidden_dim * 4, hidden_dim),
        )
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)

    def forward(self, x):
        # Wrap the forward logic so activations get recomputed during backward
        def block_fn(x):
            h = self.norm1(x)
            h = x + self.attn(h, h, h)[0]
            h = h + self.ffn(self.norm2(h))
            return h
        return checkpoint(block_fn, x, use_reentrant=False)

Set use_reentrant=False – PyTorch has recommended the non-reentrant implementation since 2.1, and it avoids subtle bugs with models that have unused parameters.

Mixed Precision Training

Training in bf16 instead of fp32 cuts weight and activation memory in half. bf16 is better than fp16 for LLMs because it has the same exponent range as fp32, which means fewer overflow issues and no need for loss scaling in most cases.
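The range difference is easy to demonstrate: bf16 keeps fp32's 8-bit exponent, so values that overflow fp16's ~65504 maximum stay representable, just at lower precision:

```python
import torch

print(torch.finfo(torch.float16).max)    # 65504.0
print(torch.finfo(torch.bfloat16).max)   # ~3.39e38, same range as fp32

x = torch.tensor(70000.0)
print(x.to(torch.float16))   # inf -> overflow, this is why fp16 needs loss scaling
print(x.to(torch.bfloat16))  # finite, ~70144 (nearest representable bf16 value)
```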

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./output",
    bf16=True,                # Use bfloat16 mixed precision
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=2e-5,
    num_train_epochs=1,
    logging_steps=10,
)

If you’re writing a custom training loop, use torch.autocast:

for batch in dataloader:
    optimizer.zero_grad()
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        outputs = model(**batch)
        loss = outputs.loss
    # bf16 needs no loss scaling; with fp16 you'd route backward/step
    # through torch.amp.GradScaler("cuda") instead
    loss.backward()
    optimizer.step()

On Ampere GPUs (A100, RTX 3090) and newer, bf16 is the right choice. Use fp16 only if you’re stuck on older Volta or Turing hardware.

8-bit Optimizers with bitsandbytes

AdamW’s optimizer states are often the largest memory consumer. Standard AdamW stores two fp32 tensors per parameter. For a 7B model, that’s 56 GB just for the optimizer. Switching to 8-bit AdamW from bitsandbytes cuts that to ~14 GB with almost no training quality loss.

pip install bitsandbytes
import bitsandbytes as bnb

# Drop-in replacement for torch.optim.AdamW
optimizer = bnb.optim.AdamW8bit(
    model.parameters(),
    lr=2e-5,
    betas=(0.9, 0.999),
    weight_decay=0.01,
)

# Training loop is identical to standard AdamW
for batch in dataloader:
    outputs = model(**batch)
    loss = outputs.loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

If you’re using HuggingFace Trainer, set it in TrainingArguments:

training_args = TrainingArguments(
    output_dir="./output",
    optim="adamw_bnb_8bit",      # 8-bit AdamW via bitsandbytes
    bf16=True,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
)

This is the highest-ROI optimization on this list. You lose essentially nothing in terms of convergence quality and save 75% of optimizer memory.

Gradient Accumulation

When you can’t fit a large batch into VRAM, gradient accumulation simulates it by running multiple smaller forward/backward passes before updating weights. Memory usage matches the small batch size, but the effective batch size is micro_batch * accumulation_steps.

accumulation_steps = 8
micro_batch_size = 2  # fits in VRAM
# effective batch size = 2 * 8 = 16

optimizer.zero_grad()
for i, batch in enumerate(dataloader):
    outputs = model(**batch)
    loss = outputs.loss / accumulation_steps  # normalize the loss
    loss.backward()

    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()

The key detail people miss: divide the loss by accumulation_steps. Without it, the accumulated gradients are accumulation_steps times larger than a true large-batch gradient – equivalent to multiplying your learning rate by accumulation_steps – and training can destabilize or diverge.
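A quick CPU check that the normalized micro-batch gradients accumulate to exactly the full-batch gradient (toy data and a hypothetical linear model):

```python
import torch

torch.manual_seed(0)
w = torch.randn(3, requires_grad=True)
data = torch.randn(8, 3)

# Full-batch gradient as the reference
(data @ w).pow(2).mean().backward()
full_grad = w.grad.clone()

# Two micro-batches of 4, each loss divided by accumulation_steps = 2
w.grad = None
for chunk in data.chunk(2):
    ((chunk @ w).pow(2).mean() / 2).backward()

print(torch.allclose(w.grad, full_grad))
```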

Putting It All Together: VRAM Budget for a 7B Model

Here’s what each optimization saves on a Llama-class 7B model:

| Configuration | Est. VRAM | Hardware |
| --- | --- | --- |
| Full fp32, AdamW, batch=4 | ~120 GB | 2x A100 80GB |
| bf16 weights + fp32 optimizer | ~72 GB | 1x A100 80GB |
| bf16 + gradient checkpointing | ~60 GB | 1x A100 80GB |
| bf16 + checkpointing + AdamW 8-bit | ~40 GB | 1x A100 40GB |
| bf16 + checkpointing + AdamW 8-bit + batch=1 + accum=16 | ~24 GB | RTX 4090 / RTX 3090 |

The bottom line: stacking bf16, gradient checkpointing, 8-bit optimizer, and gradient accumulation takes you from needing two A100s to training on a single consumer GPU. The training will be slower, but it works.
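The per-component arithmetic can be wrapped in a rule-of-thumb helper. The multipliers below are rough assumptions, not measured numbers – the activation term in particular varies widely and should be profiled, not guessed:

```python
def estimate_vram_gb(n_params_billions: float,
                     weight_bytes: int = 2,     # bf16 weights
                     grad_bytes: int = 2,       # bf16 gradients
                     opt_state_bytes: int = 2,  # 8-bit AdamW: 2 x 1 byte/param
                     activations_gb: float = 8.0) -> float:
    """Back-of-envelope VRAM estimate: bytes per parameter times billions
    of parameters, plus a flat guess for activations."""
    per_param = weight_bytes + grad_bytes + opt_state_bytes
    return n_params_billions * per_param + activations_gb

# 7B, bf16 weights/grads, 8-bit AdamW:
print(estimate_vram_gb(7))                      # 7*(2+2+2) + 8 = 50.0
# Same model with fp32 AdamW states (8 bytes/param):
print(estimate_vram_gb(7, opt_state_bytes=8))   # 7*(2+2+8) + 8 = 92.0
```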

Common Errors and Fixes

torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 2.00 GiB

The most common error in LLM training. PyTorch is asking for a contiguous block of memory that doesn’t exist. First, check what’s actually allocated:

print(torch.cuda.memory_summary())

If “reserved” is much larger than “allocated,” you have fragmentation. Set the environment variable before your script runs:

export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

This lets the allocator grow segments incrementally instead of requesting large contiguous blocks, which reduces fragmentation significantly on PyTorch 2.1+.
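The same setting can be applied from Python, as long as it happens before the first CUDA allocation:

```python
import os

# Must be set before PyTorch initializes the CUDA caching allocator,
# so do it at the very top of your script (PyTorch 2.1+)
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import torch  # imported only after the variable is set
```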

RuntimeError: Expected all tensors to be on the same device

This happens when model layers and input tensors end up on different devices. With HuggingFace models using device_map="auto", some layers may be offloaded to CPU if VRAM is tight. Make sure your inputs are moved to the GPU and force the whole model onto it:

model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="cuda"  # not "auto"
)

ValueError: fp16 mixed precision requires a GPU with compute capability >= 7.0

Your GPU predates fast hardware fp16 (compute capability below 7.0). Train in fp32 on these cards – bf16 is not an escape hatch here, since it requires Ampere (compute capability 8.0) or newer. Check compute capability:

print(torch.cuda.get_device_capability())
# (8, 0) = Ampere, (9, 0) = Hopper, (7, 0) = Volta

Loss is NaN after enabling mixed precision.

fp16 is more prone to this than bf16 due to its limited range. Switch to bf16 if your hardware supports it. If you must use fp16, make sure you’re using GradScaler:

scaler = torch.amp.GradScaler("cuda")
with torch.autocast(device_type="cuda", dtype=torch.float16):
    loss = model(**batch).loss
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()

bitsandbytes throws CUDA Setup failed on import.

The library can’t find your CUDA installation. Set the path explicitly:

export LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH
pip install bitsandbytes --force-reinstall --no-cache-dir

Verify with the library's built-in diagnostic:

python -m bitsandbytes

It reports the CUDA libraries it found and runs a quick sanity check.

Training slows down dramatically after enabling gradient checkpointing.

A 20-30% slowdown is normal. More than 50% suggests you’re checkpointing too aggressively. With HuggingFace models, the checkpointing granularity is per transformer layer, which is the right balance. If you’re wrapping custom code, checkpoint entire blocks rather than individual operations.