The Quick Version

Mixed precision training runs most operations in half precision (float16 or bfloat16) instead of full precision (float32). On GPUs with tensor cores (V100 and newer), this cuts activation memory roughly in half and can substantially speed up training. PyTorch’s torch.amp makes it a 3-line change.

pip install torch torchvision
import torch
import torch.nn as nn
from torch.amp import autocast, GradScaler
from torchvision.models import resnet50

model = resnet50().cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()
scaler = GradScaler()  # handles loss scaling for float16

for epoch in range(10):
    for batch_idx in range(100):
        inputs = torch.randn(64, 3, 224, 224).cuda()
        targets = torch.randint(0, 1000, (64,)).cuda()

        optimizer.zero_grad()

        # autocast runs eligible ops in float16
        with autocast(device_type="cuda"):
            outputs = model(inputs)
            loss = criterion(outputs, targets)

        # GradScaler prevents underflow in float16 gradients
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

    print(f"Epoch {epoch}: loss = {loss.item():.4f}")

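That’s it. Wrap your forward pass in autocast and use GradScaler for the backward pass. You can typically expect faster training (speedups of 30-50% are common on tensor-core GPUs, depending on the model) and substantially less activation memory. The model weights stay in float32. Only the eligible ops inside the autocast region, and their counterparts in the backward pass, run in float16.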

How Mixed Precision Works

Three components work together:

  1. autocast automatically picks float16 for ops that benefit from it (matmuls, convolutions) and keeps float32 for ops that need precision (softmax, layer norm, loss functions).

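  2. GradScaler multiplies the loss by a scale factor (65,536 by default) before the backward pass, so every gradient is scaled by the same factor. This prevents tiny float16 gradients from underflowing to zero. Before the optimizer step, it unscales the gradients, skips the step if any of them are inf or NaN, and then adjusts the scale factor dynamically.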

  3. Master weights stay in float32. The optimizer updates float32 weights, then the next forward pass casts them to float16 on-the-fly. This preserves training stability.
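You can check the effect of loss scaling without a GPU. The sketch below is a simplification. It uses Python's struct "e" format (IEEE 754 half precision) as a stand-in for float16 storage, together with a hypothetical gradient of 1e-8. It also applies the scale to the gradient directly, which by linearity is what scaling the loss does to every gradient in the backward pass. The scale 2**16 matches GradScaler's default initial scale.

```python
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


true_grad = 1e-8      # hypothetical tiny gradient value
scale = 2.0 ** 16     # loss-scale factor (GradScaler's default initial scale)

# Without scaling: 1e-8 is less than half of float16's smallest positive
# subnormal (2**-24, about 6e-8), so it rounds to zero.
unscaled_fp16 = to_fp16(true_grad)

# With scaling: the scaled value lands in float16's normal range and survives.
# Dividing by the scale afterwards (in higher precision) recovers the magnitude.
scaled_fp16 = to_fp16(true_grad * scale)
recovered = scaled_fp16 / scale

print(f"float16 without scaling: {unscaled_fp16}")
print(f"float16 with scaling:    {scaled_fp16}")
print(f"after unscaling:         {recovered:.3e}")
```

The recovered value differs from 1e-8 only by float16's rounding error (under 0.05%), instead of vanishing entirely.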

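# What autocast does under the hood (CUDA, default dtype float16):
# These run in float16 (fast on tensor cores):
#   - nn.Linear, nn.Conv2d, nn.ConvTranspose2d
#   - torch.matmul, torch.bmm, torch.mm
#   - F.linear, F.conv2d

# These are autocast to float32 (need precision or range):
#   - nn.LayerNorm, nn.GroupNorm
#   - F.softmax, F.log_softmax
#   - F.cross_entropy, F.binary_cross_entropy_with_logits
#   - torch.sum, torch.cumsum, torch.prod (reductions)

# Note: F.binary_cross_entropy (nn.BCELoss) raises an error inside an autocast
# region; use F.binary_cross_entropy_with_logits (nn.BCEWithLogitsLoss) instead.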
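One reason reductions stay in float32 is that a running total kept in float16 eventually loses small addends to rounding. The sketch below uses only the standard library and shows the numerical effect, not how PyTorch's kernels are written. It uses struct's half-precision format again and hypothetical input values.

```python
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


values = [0.001] * 10_000   # true sum is 10.0

# Keep the running total in float16: round after every addition.
fp16_total = 0.0
for v in values:
    fp16_total = to_fp16(fp16_total + to_fp16(v))

# Accumulate in Python float (float64) for reference.
wide_total = sum(values)

print(f"float16 running sum: {fp16_total}")
print(f"float64 running sum: {wide_total:.6f}")
```

The float16 total stops growing at 4.0. At that point, adjacent float16 values are 2**-8 (about 0.0039) apart, which is more than twice the 0.001 increment, so each addition rounds back down.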

bfloat16 vs float16

Ampere GPUs (A100, RTX 3090) and newer support bfloat16, which has the same dynamic range as float32 but reduced precision. The advantage: no GradScaler needed.

# bfloat16 — simpler, no scaler required
model = resnet50().cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

for epoch in range(10):
    for batch_idx in range(100):
        inputs = torch.randn(64, 3, 224, 224).cuda()
        targets = torch.randint(0, 1000, (64,)).cuda()

        optimizer.zero_grad()

        with autocast(device_type="cuda", dtype=torch.bfloat16):
            outputs = model(inputs)
            loss = criterion(outputs, targets)

        loss.backward()  # no scaler needed!
        optimizer.step()

    print(f"Epoch {epoch}: loss = {loss.item():.4f}")
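
Dtype      Range                  Precision                  GPU support                 GradScaler
float32    Full (max ~3.4e38)     Full (23-bit mantissa)     All                         Not needed
float16    Limited (max 65,504)   Medium (10-bit mantissa)   Tensor cores on V100+       Required
bfloat16   Same as float32        Low (7-bit mantissa)       Ampere+ (A100, RTX 30xx)    Not needed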
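You can check the range and precision columns numerically. This sketch emulates bfloat16 by rounding a float32 bit pattern to its top 16 bits (round-to-nearest-even). That matches bfloat16's layout of 1 sign, 8 exponent, and 7 mantissa bits. float16 again comes from struct's "e" format, and the test values are illustrative.

```python
import struct


def to_fp16(x: float) -> float:
    """Round to the nearest IEEE 754 half-precision (float16) value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_bf16(x: float) -> float:
    """Emulate bfloat16: keep the top 16 bits of the float32 pattern,
    rounding to nearest-even. NaN/inf inputs are not handled in this sketch."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    bits = (bits + rounding_bias) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


# Range: 70,000 exceeds float16's maximum (65,504) but fits in bfloat16.
try:
    to_fp16(70_000.0)
    fp16_overflowed = False
except (OverflowError, struct.error):
    fp16_overflowed = True
bf16_large = to_bf16(70_000.0)

# Small values: 1e-8 underflows to 0 in float16 but not in bfloat16.
fp16_tiny, bf16_tiny = to_fp16(1e-8), to_bf16(1e-8)

# Precision: 1.001 survives in float16 (10 mantissa bits)
# but rounds to 1.0 in bfloat16 (7 mantissa bits).
fp16_fine, bf16_fine = to_fp16(1.001), to_bf16(1.001)

print(f"70000 -> float16 overflow: {fp16_overflowed}, bfloat16: {bf16_large}")
print(f"1e-8  -> float16: {fp16_tiny}, bfloat16: {bf16_tiny:.3e}")
print(f"1.001 -> float16: {fp16_fine}, bfloat16: {bf16_fine}")
```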

Use bfloat16 if your GPU supports it. It’s simpler and avoids the rare numerical issues that float16 can cause with very large or small values. Check support with:

print(f"bfloat16 supported: {torch.cuda.is_bf16_supported()}")

Mixed Precision for Transformer Training

Transformers benefit substantially from mixed precision because most of their compute is matrix multiplication: the attention projections, the attention scores themselves, and the feed-forward layers. These are exactly the ops that tensor cores accelerate.

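import torch
from torch.amp import autocast, GradScaler
from transformers import AutoModelForCausalLM, AutoTokenizer
from torch.utils.data import DataLoader

# Full fine-tuning with float32 weights, gradients, and AdamW states needs about
# 16 bytes per parameter (~128 GB for 8B parameters), so this model won't fit on
# a single GPU; swap in a smaller checkpoint to try the loop on one GPU.
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B").cuda()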
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B")

optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5, weight_decay=0.01)
scaler = GradScaler()

# Gradient accumulation + mixed precision
accumulation_steps = 4

model.train()
for step in range(1000):
    input_ids = torch.randint(0, 32000, (2, 512)).cuda()  # your real data here
    attention_mask = torch.ones_like(input_ids)

    with autocast(device_type="cuda"):
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)
        loss = outputs.loss / accumulation_steps

    scaler.scale(loss).backward()

    if (step + 1) % accumulation_steps == 0:
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

    if (step + 1) % 100 == 0:
        print(f"Step {step+1}, Loss: {loss.item() * accumulation_steps:.4f}")

Note the gradient clipping: call scaler.unscale_() before clip_grad_norm_ so you’re clipping the actual gradient values, not the scaled ones.
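The loss is divided by accumulation_steps because gradients from successive backward() calls are summed, not averaged. The pure-Python sketch below checks that this reproduces the full-batch gradient. It assumes a one-parameter model y_hat = w * x with mean-squared-error loss, made-up data, and equal-sized micro-batches. Four micro-batches of two examples, each scaled by 1/4, are compared against one batch of eight.

```python
from typing import Sequence


def mse_grad(w: float, xs: Sequence[float], ys: Sequence[float]) -> float:
    """d/dw of mean((w*x - y)**2) over one batch."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


w = 0.5
xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]      # made-up data
ys = [2.1, 3.9, 6.2, 8.1, 9.8, 12.2, 13.9, 16.1]

accumulation_steps = 4
micro = len(xs) // accumulation_steps               # 2 examples per micro-batch

full_batch_grad = mse_grad(w, xs, ys)

# Emulate loss / accumulation_steps followed by repeated backward() calls:
# each micro-batch gradient is scaled by 1/accumulation_steps and summed.
accumulated_grad = 0.0
for i in range(accumulation_steps):
    batch = slice(i * micro, (i + 1) * micro)
    accumulated_grad += mse_grad(w, xs[batch], ys[batch]) / accumulation_steps

print(full_batch_grad, accumulated_grad)
```

Without the division, the accumulated gradient would be four times too large, which also changes what clip_grad_norm_ sees.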

Memory Savings and Benchmarks

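Mixed precision roughly halves activation memory, because the intermediate tensors saved for the backward pass are stored in float16 instead of float32. Weights, gradients, and optimizer states stay in float32, so total memory drops by less than half. When activations dominate, as they typically do for CNNs at large batch sizes, you can often roughly double your batch size, which further improves GPU utilization.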
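Here is a rough worked example of one activation tensor's footprint in each dtype. The shape is an assumption chosen for illustration: the 64-channel, 112x112 output of ResNet-50's first convolution at batch size 64.

```python
from typing import Tuple


def tensor_mib(shape: Tuple[int, ...], bytes_per_element: int) -> float:
    """Memory for a dense tensor of the given shape, in MiB."""
    numel = 1
    for dim in shape:
        numel *= dim
    return numel * bytes_per_element / 2**20


activation_shape = (64, 64, 112, 112)   # (batch, channels, height, width)

fp32_mib = tensor_mib(activation_shape, 4)   # float32: 4 bytes per element
fp16_mib = tensor_mib(activation_shape, 2)   # float16/bfloat16: 2 bytes per element

print(f"float32: {fp32_mib:.1f} MiB, float16: {fp16_mib:.1f} MiB")
```

ResNet-50 saves many such tensors for the backward pass, which is where the savings come from. The float32 weights, gradients, and optimizer state do not shrink, which is one reason measured savings stay below 2x.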

import gc

import torch
import torch.nn as nn
from torch.amp import autocast, GradScaler
from torchvision.models import resnet50


def measure_memory(model, batch_size: int, use_amp: bool) -> dict:
    """Measure peak GPU memory for one forward + backward pass."""
    # Free leftovers from earlier calls (including gradients) first,
    # then reset the peak counter so earlier runs don't leak into this one.
    model.zero_grad(set_to_none=True)
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

    inputs = torch.randn(batch_size, 3, 224, 224, device="cuda")
    targets = torch.randint(0, 1000, (batch_size,), device="cuda")
    criterion = nn.CrossEntropyLoss()

    if use_amp:
        scaler = GradScaler()
        with autocast(device_type="cuda"):
            output = model(inputs)
            loss = criterion(output, targets)
        scaler.scale(loss).backward()
    else:
        output = model(inputs)
        loss = criterion(output, targets)
        loss.backward()

    peak_mb = torch.cuda.max_memory_allocated() / 1024**2
    return {"batch_size": batch_size, "amp": use_amp, "peak_memory_mb": round(peak_mb)}


model = resnet50().cuda()

fp32 = measure_memory(model, batch_size=64, use_amp=False)
amp = measure_memory(model, batch_size=64, use_amp=True)
amp_2x = measure_memory(model, batch_size=128, use_amp=True)

print(f"FP32 (bs=64):  {fp32['peak_memory_mb']} MB")
print(f"AMP  (bs=64):  {amp['peak_memory_mb']} MB")
print(f"AMP  (bs=128): {amp_2x['peak_memory_mb']} MB")

Typical results on RTX 3090:

  • ResNet-50, bs=64, FP32: ~8.2 GB
  • ResNet-50, bs=64, AMP: ~4.8 GB (42% reduction)
  • ResNet-50, bs=128, AMP: ~8.5 GB (2x batch, about the same memory as FP32 at bs=64)

Common Errors and Fixes

Loss becomes NaN after a few steps

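Usually float16 overflow. GradScaler handles occasional overflow automatically: it detects inf/NaN gradients, skips that optimizer step, and lowers the scale. If the loss itself is NaN in the forward pass, or steps keep getting skipped, the issue is likely in your model rather than AMP. Check for values that exceed float16’s maximum (65,504) in your activations or loss computation.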

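# GradScaler skips the optimizer step when it finds inf/NaN gradients and
# lowers the scale in update(). One way to detect skipped steps:
scale_before = scaler.get_scale()
scaler.step(optimizer)
scaler.update()
if scaler.get_scale() < scale_before:
    print("Warning: inf/NaN gradients detected, optimizer step skipped")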
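GradScaler's skip-and-shrink behavior is easy to model. The class below is a hypothetical, simplified sketch that works on plain lists of floats; DynamicLossScaler is not a PyTorch API. Its defaults mirror GradScaler's documented ones (init_scale=2**16, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000). The real GradScaler splits this work across step() and update() and operates on GPU tensors.

```python
import math
from typing import List, Optional


class DynamicLossScaler:
    """Simplified sketch of dynamic loss scaling; not PyTorch's actual code."""

    def __init__(self, init_scale: float = 2.0**16, growth_factor: float = 2.0,
                 backoff_factor: float = 0.5, growth_interval: int = 2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._finite_steps = 0

    def step(self, scaled_grads: List[float]) -> Optional[List[float]]:
        """Return unscaled gradients to apply, or None if the step is skipped."""
        if not all(math.isfinite(g) for g in scaled_grads):
            # inf/NaN found: skip the optimizer step and shrink the scale.
            self.scale *= self.backoff_factor
            self._finite_steps = 0
            return None
        # Unscale with the scale that produced these gradients.
        grads = [g / self.scale for g in scaled_grads]
        self._finite_steps += 1
        if self._finite_steps == self.growth_interval:
            # A long run of finite gradients: try a larger scale.
            self.scale *= self.growth_factor
            self._finite_steps = 0
        return grads


# Usage with hypothetical gradient values and a short growth interval.
scaler = DynamicLossScaler(growth_interval=3)
history = []
for scaled in ([1.0, 2.0], [float("inf"), 1.0], [4.0, 8.0], [4.0, 8.0], [4.0, 8.0]):
    applied = scaler.step(scaled)
    history.append((applied is not None, scaler.scale))
    print(f"applied={applied is not None}, scale={scaler.scale}")
```

The single overflow halves the scale and skips that update. Three clean steps later, the scale grows back. This is why a falling get_scale() signals skipped steps.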

Training is slower with AMP, not faster

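Your GPU doesn’t have tensor cores (e.g. the GTX 1080 and other Pascal-generation cards), or your model is too small for the speedup to outweigh the cost of casting between dtypes. Tensor cores are also used most efficiently when float16 matrix dimensions are multiples of 8, so keep batch size, hidden dimensions, and vocabulary size at multiples of 8 where you can.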
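Rounding a dimension up to a multiple of 8 is simple integer arithmetic. Here is a minimal sketch. round_up_to_multiple is a hypothetical helper, and padding a vocabulary assumes the extra token ids are simply never used.

```python
def round_up_to_multiple(n: int, multiple: int = 8) -> int:
    """Smallest value >= n that is divisible by `multiple`."""
    return ((n + multiple - 1) // multiple) * multiple


# Hypothetical example: pad GPT-2's 50,257-token vocabulary so the output
# projection dimension is a multiple of 8, and pick a batch size likewise.
padded_vocab = round_up_to_multiple(50_257)
batch_size = round_up_to_multiple(60)
print(padded_vocab, batch_size)
```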

Gradient clipping doesn’t work correctly

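You must call scaler.unscale_(optimizer) before clip_grad_norm_. Without unscaling, you clip the scaled gradients, so max_norm is effectively divided by the current scale factor. With a scale above 1 (the usual case, since the default starts at 65,536), almost every step gets clipped far more aggressively than intended. With a scale below 1, clipping rarely triggers.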
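The sketch below shows why the order matters, using only the standard library and two hypothetical gradient values. clip_by_global_norm mimics the scaling rule of torch.nn.utils.clip_grad_norm_, which multiplies by max_norm / (total_norm + 1e-6) when that factor is below 1. It is not the PyTorch function itself.

```python
import math


def clip_by_global_norm(grads, max_norm):
    """Scale grads down so their L2 norm is at most max_norm."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    coef = max_norm / (total_norm + 1e-6)
    return [g * coef for g in grads] if coef < 1.0 else list(grads)


true_grads = [0.3, 0.4]          # hypothetical gradients, L2 norm 0.5
scale = 2.0 ** 16
max_norm = 1.0
scaled = [g * scale for g in true_grads]

# Wrong order: clip the scaled gradients, then unscale.
wrong = [g / scale for g in clip_by_global_norm(scaled, max_norm)]

# Right order (what scaler.unscale_ enables): unscale first, then clip.
right = clip_by_global_norm([g / scale for g in scaled], max_norm)

print("wrong:", wrong)   # norm ends up near max_norm / scale
print("right:", right)   # unchanged, since the true norm 0.5 < max_norm
```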

Model validation metrics are slightly different with AMP

Some ops produce slightly different results in float16. This is expected, and the effect on final metrics is usually small. If you see larger gaps, specific layers might need float32. Force them by disabling autocast locally:

with autocast(device_type="cuda"):
    x = model.encoder(inputs)  # runs in float16

    # Disable autocast locally to force this part into float32,
    # casting its float16 input back to float32 explicitly
    with autocast(device_type="cuda", enabled=False):
        output = model.sensitive_head(x.float())
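A typical reason a layer needs float32 is overflow in an intermediate result. This sketch uses only the standard library and assumes a naive softmax without max-subtraction, applied to hypothetical logits. exp(12) is about 162,755, beyond float16's maximum of 65,504. Here struct raises an error where a GPU kernel would silently produce inf.

```python
import math
import struct


def to_fp16(x: float) -> float:
    """Round to the nearest float16 value; raises if x is out of range."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def naive_softmax(logits, cast):
    """Softmax without max-subtraction, rounding every intermediate with `cast`."""
    exps = [cast(math.exp(z)) for z in logits]
    total = cast(sum(exps))
    return [cast(e / total) for e in exps]


logits = [12.0, 1.0, 0.5]   # hypothetical logits

probs_wide = naive_softmax(logits, float)   # float64 reference
try:
    probs_fp16 = naive_softmax(logits, to_fp16)
except (OverflowError, struct.error):
    probs_fp16 = None   # exp(12) does not fit in float16

print("float64:", probs_wide)
print("float16:", probs_fp16)
```

Library softmax kernels subtract the maximum logit first, which avoids this particular overflow. Hand-written ops in a custom head may not, and that is where a local float32 region helps.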

Hugging Face Trainer already handles AMP

If you’re using the Transformers Trainer, just pass fp16=True or bf16=True in TrainingArguments. Don’t add manual AMP on top. The Trainer handles the scaler and autocast internally, and it handles gradient accumulation through its gradient_accumulation_steps argument.