The Quick Version

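torch.compile is PyTorch's built-in compiler. It captures your model's operations into graphs, fuses operations, reduces Python overhead, and generates optimized GPU kernels. For many models it delivers speedups in the 20-50% range with a single line of code, though results vary by model, hardware, and workload.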

pip install torch torchvision  # PyTorch 2.0+
import time

import torch
from torchvision.models import resnet50

model = resnet50(weights="IMAGENET1K_V2").cuda().eval()

# One line to compile
compiled_model = torch.compile(model)

x = torch.randn(32, 3, 224, 224).cuda()

# Warm up with the same shape and grad mode as the timed loop. The first call
# is slow because compilation happens here. A new shape or a switch into
# no_grad would trigger a recompile inside the timed loop and skew the result.
with torch.no_grad():
    for _ in range(3):
        _ = compiled_model(x)
        _ = model(x)  # warm up eager too (lazy CUDA init, allocator caching)

def time_100(fn):
    torch.cuda.synchronize()  # don't count leftover queued GPU work
    start = time.perf_counter()
    with torch.no_grad():
        for _ in range(100):
            _ = fn(x)
    torch.cuda.synchronize()  # wait for the GPU to finish before stopping the clock
    return time.perf_counter() - start

print(f"Compiled: {time_100(compiled_model):.2f}s")
print(f"Eager: {time_100(model):.2f}s")

On an A100 with ResNet-50, expect the compiled model to run roughly 30-40% faster than eager mode. The exact figure depends on batch size, precision, and PyTorch version. Speedups also vary by model, and transformer models often see larger gains.
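Hand-rolled timing loops are easy to get wrong. Common mistakes include missing synchronization, timing the compile itself, and forgetting warm-up. As an alternative, here is a minimal sketch that uses `torch.utils.benchmark.Timer`, which synchronizes CUDA work when needed. The helper names `median_ms` and `compare` are hypothetical. The `backend` parameter exists only so the same helper can be smoke-tested with the no-codegen `"eager"` backend.

```python
import torch
from torch.utils.benchmark import Timer


def median_ms(fn, x, min_run_time=1.0):
    """Median runtime of fn(x) in milliseconds, measured with torch.utils.benchmark."""
    with torch.no_grad():
        fn(x)  # warm-up call; for a compiled module this is where compilation happens
        timer = Timer(stmt="fn(x)", globals={"fn": fn, "x": x})
        # blocked_autorange repeats the statement until min_run_time seconds have elapsed
        measurement = timer.blocked_autorange(min_run_time=min_run_time)
    return measurement.median * 1e3


def compare(model, x, backend="inductor", min_run_time=1.0):
    """Return (eager_ms, compiled_ms) for the same model and input."""
    compiled = torch.compile(model, backend=backend)
    eager_ms = median_ms(model, x, min_run_time)
    compiled_ms = median_ms(compiled, x, min_run_time)
    return eager_ms, compiled_ms


# Hypothetical usage on the ResNet-50 from above (needs a CUDA GPU):
# eager_ms, compiled_ms = compare(model, torch.randn(32, 3, 224, 224).cuda())
# print(f"eager {eager_ms:.1f} ms, compiled {compiled_ms:.1f} ms")
```

`Timer` also pins the CPU thread count (one thread by default). This only matters when benchmarking on the CPU.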

Understanding Compilation Modes

torch.compile offers several modes that trade compilation time for runtime speed. The three most commonly used are:

# Default: balanced compilation time and speedup
model_default = torch.compile(model, mode="default")

# Max autotune: slower compile, fastest runtime (benchmarks many kernel configs; also uses CUDA graphs on GPU)
model_fast = torch.compile(model, mode="max-autotune")

# Reduce overhead: uses CUDA graphs to cut per-call launch overhead; best for small batches with static shapes
model_low_overhead = torch.compile(model, mode="reduce-overhead")
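| Mode | Compile time | Runtime speed | Best for |
|---|---|---|---|
| default | 30-60 s | Good (20-30% faster) | General use |
| max-autotune | 5-15 min | Best (30-50% faster) | Production inference |
| reduce-overhead | Similar to default, plus CUDA graph capture | Good; largest gains when kernel-launch overhead dominates | Small-batch inference with static shapes (uses extra memory) |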

The timings above are rough and depend heavily on model size and hardware. For production inference, where you compile once and run millions of times, max-autotune is usually worth the wait. For development, stick with default.

Compiling Training Loops

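torch.compile isn't just for inference. It speeds up training too. Compiling the model covers both the forward pass and the backward pass, while the optimizer step stays in eager mode unless you compile it separately: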

import torch
import torch.nn as nn
from torch.optim import AdamW
from torchvision.models import vit_b_16

model = vit_b_16().cuda()
optimizer = AdamW(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

# Compile the model (covers forward + backward). The compiled wrapper shares
# parameters with `model`, so the optimizer above still updates the right tensors.
compiled_model = torch.compile(model)

# Training loop: same as eager, just call the compiled model
for epoch in range(10):
    for batch_idx in range(100):
        # Synthetic batch for illustration; substitute your DataLoader here
        x = torch.randn(32, 3, 224, 224).cuda()
        y = torch.randint(0, 1000, (32,)).cuda()

        optimizer.zero_grad()
        output = compiled_model(x)
        loss = criterion(output, y)
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

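The compiler captures the forward graph and uses AOTAutograd to generate the matching backward graph ahead of time, so both passes get fused kernels. Vision Transformers (ViT) can see training speedups in the 25-40% range, because the layer norms, activations, and residual additions around each attention and MLP block offer many fusion opportunities.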
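A quick way to confirm that compiled training computes the same thing as eager training is to compare losses and gradients on a small model. This minimal sketch uses a toy `nn.Sequential` (an assumption, standing in for ViT) and the built-in `"aot_eager"` debugging backend. That backend runs Dynamo and AOTAutograd, so it captures both the forward and backward graphs, but it executes them with eager kernels instead of Inductor-generated code. As a result, it works on a CPU without a compiler toolchain.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(16, 32), nn.GELU(), nn.LayerNorm(32), nn.Linear(32, 4))
loss_fn = nn.CrossEntropyLoss()
x = torch.randn(8, 16)
y = torch.randint(0, 4, (8,))


def loss_and_grads(module):
    """One forward + backward pass; returns the loss and a copy of every gradient."""
    module.zero_grad()
    loss = loss_fn(module(x), y)
    loss.backward()
    # `module` may be the compiled wrapper; it shares parameters with `model`
    return loss.detach(), [p.grad.clone() for p in model.parameters()]


eager_loss, eager_grads = loss_and_grads(model)

# aot_eager: Dynamo + AOTAutograd capture of forward and backward, no Inductor codegen
compiled = torch.compile(model, backend="aot_eager")
compiled_loss, compiled_grads = loss_and_grads(compiled)

print(f"eager loss {eager_loss.item():.6f}, compiled loss {compiled_loss.item():.6f}")
```

Switching the backend back to the default (`"inductor"`) on a GPU checks the fully compiled path the same way, within floating-point tolerance.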

Compiling Custom Functions

You can compile individual functions, not just full models. This is useful for optimizing specific bottlenecks:

import torch

@torch.compile
def custom_loss(predictions, targets, weights):
    """Weighted focal loss: lots of element-wise ops that benefit from fusion."""
    ce_loss = torch.nn.functional.cross_entropy(predictions, targets, reduction="none")
    p_t = torch.exp(-ce_loss)
    focal_weight = (1 - p_t) ** 2.0
    weighted_loss = focal_weight * ce_loss * weights
    return weighted_loss.mean()

# The function is compiled on first call
preds = torch.randn(64, 10).cuda()
targets = torch.randint(0, 10, (64,)).cuda()
weights = torch.ones(64).cuda()

loss = custom_loss(preds, targets, weights)

Chains of element-wise operations (multiply, add, exp, pow) are where torch.compile shines brightest. The compiler fuses them into a single GPU kernel instead of launching a separate kernel for each op, so intermediate results no longer make a round trip through GPU memory between ops.
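Fusion is only possible when the whole chain lands in one captured graph. A way to check this is to pass torch.compile a custom backend. A custom backend is any callable that takes Dynamo's FX `GraphModule` plus example inputs and returns a callable. The minimal sketch below records the graph and runs it unchanged. The backend name `inspect_backend` is hypothetical. To see the fused kernel Inductor generates, you would instead run with the environment variable `TORCH_LOGS="output_code"`.

```python
import torch
import torch.nn.functional as F

captured = []


def inspect_backend(gm: torch.fx.GraphModule, example_inputs):
    """Minimal custom backend: record the FX graph Dynamo captured, then run it as-is."""
    captured.append(gm)
    return gm.forward


def focal_loss(predictions, targets, weights):
    ce_loss = F.cross_entropy(predictions, targets, reduction="none")
    p_t = torch.exp(-ce_loss)
    focal_weight = (1 - p_t) ** 2.0
    return (focal_weight * ce_loss * weights).mean()


compiled_loss = torch.compile(focal_loss, backend=inspect_backend)
preds, targets, weights = torch.randn(64, 10), torch.randint(0, 10, (64,)), torch.ones(64)
loss = compiled_loss(preds, targets, weights)

# Every op in the chain should appear in a single graph, ready for fusion.
op_nodes = [n for n in captured[0].graph.nodes if n.op in ("call_function", "call_method")]
print(f"{len(captured)} graph(s), {len(op_nodes)} ops:")
for node in op_nodes:
    print("  ", node.target)
```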

Handling Dynamic Shapes

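By default (dynamic=None), torch.compile first specializes on the exact input shapes. When a shape changes, it recompiles and tries to make the dimension that changed shape-generic. If you know up front that shapes will vary, for example with variable-length sequences or dynamic batch sizes, use dynamic=True to request dynamic shapes from the first compile: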

model = torch.compile(model, dynamic=True)

# These share one shape-generic graph (no recompilation). Sizes 0 and 1 are
# specialized by default, so a batch=1 call would still compile its own graph.
with torch.no_grad():
    out1 = model(torch.randn(8, 3, 224, 224).cuda())    # batch=8
    out2 = model(torch.randn(16, 3, 224, 224).cuda())   # batch=16
    out3 = model(torch.randn(32, 3, 224, 224).cuda())   # batch=32

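Without dynamic=True, the first shape change still costs a recompilation. With dynamic=False, every new shape does, until Dynamo's recompile limit is reached and it falls back to eager execution. Each recompilation can take as long as the original compile. With dynamic=True, the compiler generates shape-generic code from the start.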

For NLP models with variable sequence lengths, dynamic=True is especially useful:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-3.1-8B-Instruct"  # gated on the Hugging Face Hub; requires access approval

# Load in bfloat16: fp32 weights for an 8B-parameter model alone need about 32 GB of GPU memory
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16).cuda().eval()
model = torch.compile(model, dynamic=True)

tokenizer = AutoTokenizer.from_pretrained(model_id)

# Variable-length inputs: no recompilation per sequence length
for text in ["Short prompt.", "A much longer prompt that contains more tokens than the previous one."]:
    inputs = tokenizer(text, return_tensors="pt").to("cuda")
    with torch.no_grad():
        outputs = model(**inputs)

Debugging Compilation Issues

When torch.compile fails or produces wrong results, use these tools:

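import torch
import torch._dynamo as dynamo

# See what the compiler is doing (or run with TORCH_LOGS="graph_breaks,recompiles")
dynamo.config.verbose = True

# Find unsupported operations (graph breaks); explain() returns a report object
explanation = dynamo.explain(model)(torch.randn(1, 3, 224, 224).cuda())
print(explanation)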

explain() shows you where the compiler had to “break” the graph — points where it falls back to eager execution. Fewer graph breaks = more optimization opportunities.
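When you would rather have graph breaks fail loudly than be reported, pass `fullgraph=True`. The minimal sketch below compiles a toy function with a data-dependent branch twice. With `fullgraph=True`, the graph break raises an error. Without it, Dynamo silently splits the graph and still returns the correct result. The sketch uses the `"eager"` backend so that only Dynamo's capture is exercised. The exact exception class differs across PyTorch versions, so it catches `Exception`.

```python
import torch


def branchy(x):
    if x.sum() > 0:  # data-dependent Python branch: Dynamo must break the graph here
        return x * 2
    return x - 1


# fullgraph=True turns every graph break into an error instead of an eager fallback.
strict = torch.compile(branchy, backend="eager", fullgraph=True)
try:
    strict(torch.randn(4))
    error = None
except Exception as exc:  # exact exception type varies across PyTorch versions
    error = exc

print(type(error).__name__ if error is not None else "compiled as one graph")
```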

Common causes of graph breaks:

# BAD: data-dependent control flow causes graph breaks
def forward(self, x):
    if x.sum() > 0:  # value not known at compile time
        return self.path_a(x)
    return self.path_b(x)

# GOOD: use torch.where instead. This gives one graph, but both paths are now
# computed on every call, trading extra compute for fewer graph breaks.
def forward(self, x):
    return torch.where(x.sum() > 0, self.path_a(x), self.path_b(x))

Exporting Compiled Models

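For deployment, use torch.export to capture the model as a standalone graph. You export the original nn.Module, not the object returned by torch.compile. (TorchScript and ONNX via torch.onnx.export are alternative export paths.)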

# Export with torch.export (PyTorch 2.1+)
import torch
from torch.export import export
from torchvision.models import resnet50

model = resnet50(weights="IMAGENET1K_V2").cuda().eval()
example_input = torch.randn(1, 3, 224, 224).cuda()

exported = export(model, (example_input,))
torch.export.save(exported, "resnet50_exported.pt2")

# Load and run
loaded = torch.export.load("resnet50_exported.pt2")
output = loaded.module()(example_input)

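torch.export either captures the whole program as a single graph or raises an error. Unlike torch.compile, it never silently falls back to eager execution at a graph break. The resulting ExportedProgram is the input to ahead-of-time deployment paths such as AOTInductor and ExecuTorch, which target environments where running torch.compile isn't practical (like embedded devices or serverless functions).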

Common Errors and Fixes

torch._dynamo.exc.Unsupported: call_function ...

An operation isn’t supported by the compiler. Check if you’re using a niche third-party library inside the compiled region. Move unsupported code outside the compiled function, or use torch.compiler.disable() to mark specific functions as not-compilable:

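import torch

@torch.compiler.disable
def legacy_preprocessing(x):
    # This function won't be compiled: it runs eagerly, with a graph break around the call.
    # some_unsupported_op is a placeholder for whatever code the compiler can't handle.
    return some_unsupported_op(x)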

Compilation takes 10+ minutes

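Use mode="default" instead of mode="max-autotune". Or compile during a warmup phase so users don't wait. To reuse compiled artifacts across processes, enable Inductor's on-disk FX graph cache with TORCHINDUCTOR_FX_GRAPH_CACHE=1 (already the default in recent PyTorch releases). In CI/CD, also persist the cache directory between runs.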
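One way to wire this up in a CI job is to set the cache variables at the top of the entry script. In this minimal sketch, `TORCHINDUCTOR_CACHE_DIR` is the Inductor variable that controls where the cache lives. The path `/tmp/inductor-cache` is a hypothetical placeholder for a directory your CI system saves and restores between runs. Setting the variables before importing torch is the safest ordering.

```python
import os

# Set before importing torch so Inductor picks the values up from the start.
os.environ.setdefault("TORCHINDUCTOR_FX_GRAPH_CACHE", "1")
# Hypothetical path: point this at a directory your CI system persists between runs.
os.environ.setdefault("TORCHINDUCTOR_CACHE_DIR", "/tmp/inductor-cache")

import torch  # noqa: E402  (deliberately imported after the environment is configured)

print("Inductor cache dir:", os.environ["TORCHINDUCTOR_CACHE_DIR"])
```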

Results differ between compiled and eager mode

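Floating-point reordering during fusion can cause tiny numerical differences (on the order of 1e-6 in float32). This is expected and usually harmless. If you see large differences, first test any custom ops in eager mode. If eager is correct, compile with backend="aot_eager", which skips Inductor's code generation, to narrow down which stage introduces the discrepancy.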
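Because of this reordering, equivalence checks between compiled and eager results should use a tolerance rather than exact equality. The following minimal sketch wraps that check in a hypothetical helper, `check_against_eager`. The tolerances `rtol=1e-4, atol=1e-5` are illustrative choices, not official thresholds. The `backend` argument lets you repeat the check with `"aot_eager"` when bisecting a mismatch.

```python
import torch


def check_against_eager(fn, *inputs, backend="inductor", rtol=1e-4, atol=1e-5):
    """Assert that compiled fn matches eager fn within tolerance; return the max abs error."""
    expected = fn(*inputs)
    actual = torch.compile(fn, backend=backend)(*inputs)
    # Exact equality is too strict: fusion may reorder floating-point operations.
    torch.testing.assert_close(actual, expected, rtol=rtol, atol=atol)
    return (actual - expected).abs().max().item()


def scaled_softmax(x):
    return torch.softmax(x, dim=-1) * x.exp()


# Hypothetical usage on a GPU with the default Inductor backend:
# print(check_against_eager(scaled_softmax, torch.randn(16, 32).cuda()))
```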

RuntimeError: shape mismatch after recompilation

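Dynamo guards compiled graphs on input shapes, so a new shape should normally trigger a recompile rather than reuse a stale graph. If you still hit a shape error tied to earlier compilations, enable dynamic shapes (dynamic=True) or clear the compile cache with torch._dynamo.reset() and retry. A related symptom is a warning that the recompile limit was hit, meaning too many distinct shapes forced recompiles. dynamic=True usually fixes that as well.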

Compilation fails on custom CUDA kernels

torch.compile can't trace through raw CUDA kernels. Wrap them with torch.library.custom_op (PyTorch 2.4+) to make them compiler-compatible, or keep them outside the compiled region.

When torch.compile Helps Most

Biggest gains: transformer models, element-wise operations, repeated small ops, and code with lots of Python overhead between tensor operations.

Smallest gains: models already bottlenecked by large matrix multiplications (like big linear layers), IO-bound workloads, and models that use lots of unsupported operations.

The rule of thumb: if your GPU utilization is below 70% during inference, torch.compile will likely help by reducing CPU overhead and fusing operations. If you’re already at 90%+ GPU utilization, the gains will be modest.