## The Quick Version
torch.compile is PyTorch’s built-in compiler that fuses operations, eliminates overhead, and generates optimized GPU kernels. It speeds up most models by 20-50% with a single line of code.
On an A100 with ResNet-50, expect the compiled model to be 30-40% faster than eager mode. The speedup varies by model; transformer models often see even bigger gains.
## Understanding Compilation Modes
torch.compile has three modes that trade compilation time for runtime speed:
| Mode | Compile Time | Runtime Speed | Best For |
|---|---|---|---|
| default | 30-60s | Good (20-30% faster) | General use |
| max-autotune | 5-15min | Best (30-50% faster) | Production inference |
| reduce-overhead | 15-30s | Good | Interactive/dynamic workloads |
For production inference where you compile once and run millions of times, max-autotune is worth the wait. For development, stick with default.
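The mode is selected via the mode argument. A sketch (the Linear model is a placeholder; compilation is lazy, so nothing is built until the first call):

```python
import torch
import torch.nn as nn

model = nn.Linear(64, 64)  # placeholder model

# Development: quick compile, solid speedup.
dev_model = torch.compile(model, mode="default")

# Production inference: long autotuning search, best runtime.
prod_model = torch.compile(model, mode="max-autotune")

# Overhead-bound workloads (many small ops, small batches).
lite_model = torch.compile(model, mode="reduce-overhead")
```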
## Compiling Training Loops
torch.compile isn’t just for inference. It speeds up training too, especially the forward and backward passes:
The compiler traces through both the forward and backward passes, fusing operations across them. Vision Transformers (ViT) typically see 25-40% training speedup because attention operations have lots of fusion opportunities.
## Compiling Custom Functions
You can compile individual functions, not just full models. This is useful for optimizing specific bottlenecks:
Element-wise operations (multiply, add, exp, pow) chained together are where torch.compile shines brightest — the compiler fuses them into a single GPU kernel instead of launching separate kernels for each op.
## Handling Dynamic Shapes
By default, torch.compile recompiles when input shapes change. For variable-length sequences or dynamic batch sizes, use dynamic=True:
Without dynamic=True, each new shape triggers a full recompilation (30-60 seconds each). With it, the compiler generates shape-generic code.
For NLP models with variable sequence lengths, dynamic=True is essential:
## Debugging Compilation Issues
When torch.compile fails or produces wrong results, use these tools:
explain() shows you where the compiler had to “break” the graph — points where it falls back to eager execution. Fewer graph breaks = more optimization opportunities.
Common causes of graph breaks:
- Data-dependent Python control flow (e.g. branching on a tensor value)
- Calls that pull values back into Python, like tensor.item() or tensor.tolist()
- print() and logging calls inside the compiled region
- Unsupported third-party or C-extension functions
## Exporting Compiled Models
For deployment, capture the model graph with torch.export (the result can then be lowered further, e.g. to ONNX or AOTInductor):
torch.export captures the full graph without graph breaks, making it suitable for deployment to environments where you can’t run torch.compile (like embedded systems or serverless functions).
## Common Errors and Fixes
### torch._dynamo.exc.Unsupported: call_function ...
An operation isn’t supported by the compiler. Check if you’re using a niche third-party library inside the compiled region. Move unsupported code outside the compiled function, or use torch.compiler.disable() to mark specific functions as not-compilable:
### Compilation takes 10+ minutes
Use mode="default" instead of mode="max-autotune". Or compile during a warmup phase so users don’t wait. For CI/CD, cache compiled artifacts with TORCHINDUCTOR_FX_GRAPH_CACHE=1.
### Results differ between compiled and eager mode
Floating-point reordering during fusion can cause tiny numerical differences (on the order of 1e-6). This is expected and usually harmless. If you see large differences, you likely have a bug in custom ops; test them in eager mode first.
### RuntimeError: shape mismatch after recompilation
You have a cached compilation from a different shape and dynamic=True isn’t set. Either enable dynamic shapes or clear the cache: torch._dynamo.reset().
### Compilation fails on custom CUDA kernels
torch.compile can’t trace through raw CUDA kernels. Wrap them with torch.library.custom_op to make them compiler-compatible, or keep them outside the compiled region.
## When torch.compile Helps Most
Biggest gains: transformer models, element-wise operations, repeated small ops, and code with lots of Python overhead between tensor operations.
Smallest gains: models already bottlenecked by large matrix multiplications (like big linear layers), IO-bound workloads, and models that use lots of unsupported operations.
The rule of thumb: if your GPU utilization is below 70% during inference, torch.compile will likely help by reducing CPU overhead and fusing operations. If you’re already at 90%+ GPU utilization, the gains will be modest.
## Related Guides
- How to Build a Model Training Checkpoint Pipeline with PyTorch
- How to Optimize Model Inference with ONNX Runtime
- How to Set Up Multi-GPU Training with PyTorch
- How to Speed Up Training with Mixed Precision and PyTorch AMP
- How to Use PyTorch FlexAttention for Fast LLM Inference
- How to Deploy Models to Edge Devices with ONNX and TensorRT
- How to Optimize Docker Images for ML Model Serving
- How to Scale ML Training and Inference with Ray
- How to Serve ML Models with NVIDIA Triton Inference Server
- How to Profile and Optimize GPU Memory for LLM Training