PyTorch’s FlexAttention API lets you write custom attention patterns in plain Python and have torch.compile fuse them into a single optimized Triton kernel. No hand-written CUDA. No messing with FlashAttention forks. You define a score_mod or mask_mod function, and PyTorch generates a kernel that runs within 10% of FlashAttention 2 on Ampere GPUs.
Since PyTorch 2.5, FlexAttention has been available for training. The real news is the FlexDecoding backend – optimized specifically for inference workloads where you have a single query token attending to a long KV cache. In gpt-fast benchmarks, FlexDecoding delivers 1.22x to 2.04x speedup over SDPA on LLaMA 3.1-8B, and up to 1.66x on the 70B model at 16k context length.
Quick Start: Causal Decoding in 20 Lines
Here is the minimal setup for autoregressive decoding with FlexAttention. This works on PyTorch 2.5+ (the current stable release is 2.10).
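A minimal sketch of a single decode step. The shapes, the `pos` position tensor, and the `causal_mask` helper are illustrative; only `flex_attention` and `create_block_mask` come from `torch.nn.attention.flex_attention`:

```python
import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

flex_attention = torch.compile(flex_attention)  # fuses the call into one Triton kernel

B, H, D, KV_LEN = 1, 32, 128, 4096  # illustrative shapes
k_cache = torch.randn(B, H, KV_LEN, D, device="cuda", dtype=torch.bfloat16)
v_cache = torch.randn(B, H, KV_LEN, D, device="cuda", dtype=torch.bfloat16)

# Absolute position of the single query token in the sequence (kept on the GPU).
pos = torch.tensor(2047, device="cuda")

def causal_mask(b, h, q_idx, kv_idx):
    # q_idx is 0 for the lone decode token, so shift it by its absolute position.
    return (q_idx + pos) >= kv_idx

block_mask = create_block_mask(causal_mask, B=None, H=None,
                               Q_LEN=1, KV_LEN=KV_LEN, device="cuda")

q = torch.randn(B, H, 1, D, device="cuda", dtype=torch.bfloat16)  # one new token
out = flex_attention(q, k_cache, v_cache, block_mask=block_mask)  # (B, H, 1, D)
```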
PyTorch automatically selects the FlexDecoding backend when it detects a short query (1-2 tokens) against a long KV sequence. No flag to flip – it just works.
The score_mod and mask_mod APIs
FlexAttention exposes two hooks into the attention computation:
- `score_mod(score, b, h, q_idx, kv_idx) -> Tensor` – modifies the raw dot-product score before softmax. The `score` argument is a scalar tensor. Return the modified score.
- `mask_mod(b, h, q_idx, kv_idx) -> bool` – returns `True` if position `(q_idx, kv_idx)` should attend. Positions returning `False` get masked out.
You can use either one, or both. mask_mod is more efficient when your pattern is purely boolean (causal, sliding window, block-sparse) because PyTorch can build a BlockMask and skip entire 128x128 blocks of computation.
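A sketch of each hook. The soft-cap value of 50.0 for the `score_mod` (Gemma 2-style) and the 1024-token sliding window for the `mask_mod` are illustrative choices:

```python
import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

def softcap_score_mod(score, b, h, q_idx, kv_idx):
    # Logit soft-capping: squash scores into (-50, 50) before softmax.
    return 50.0 * torch.tanh(score / 50.0)

WINDOW = 1024  # sliding-window size, illustrative

def sliding_window_mask(b, h, q_idx, kv_idx):
    # Causal, but only attend to the last WINDOW tokens.
    return (q_idx >= kv_idx) & (q_idx - kv_idx < WINDOW)

# Build the BlockMask once for the full sequence; B=None / H=None broadcast
# the mask across batch and heads.
block_mask = create_block_mask(sliding_window_mask, B=None, H=None,
                               Q_LEN=8192, KV_LEN=8192, device="cuda")

q = torch.randn(1, 32, 8192, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn(1, 32, 8192, 64, device="cuda", dtype=torch.bfloat16)
v = torch.randn(1, 32, 8192, 64, device="cuda", dtype=torch.bfloat16)
out = flex_attention(q, k, v, score_mod=softcap_score_mod, block_mask=block_mask)
```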
The BlockMask is precomputed once and reused. During decoding, you slice it to match your current sequence position. The default block size is 128 – do not change it unless you have a specific reason, since non-128 block sizes can trigger compilation failures (see Troubleshooting below).
Grouped Query Attention (GQA)
Most production models (LLaMA 3, Mistral, Gemma 2) use GQA where fewer KV heads are shared across query heads. FlexAttention handles this natively with enable_gqa=True:
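A sketch with an assumed LLaMA-3-style layout of 32 query heads sharing 8 KV heads:

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

B, S, D = 1, 4096, 128  # illustrative shapes
q = torch.randn(B, 32, 1, D, device="cuda", dtype=torch.bfloat16)  # 1 decode token
k = torch.randn(B, 8, S, D, device="cuda", dtype=torch.bfloat16)   # 8 KV heads
v = torch.randn(B, 8, S, D, device="cuda", dtype=torch.bfloat16)

compiled_flex = torch.compile(flex_attention)

# No repeat_interleave needed: the fused kernel broadcasts each KV head
# across its group of 4 query heads internally.
out = compiled_flex(q, k, v, enable_gqa=True)  # shape (B, 32, 1, D)
```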
Without enable_gqa=True, you would need to manually repeat_interleave the KV tensors, which wastes memory. The flag tells FlexAttention to handle the broadcasting internally in the fused kernel.
Offset Wrapping for Decoding Loops
A common pattern is wrapping an existing score_mod or mask_mod to add positional offset support. This keeps your attention logic clean and reusable:
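One way to write such a wrapper; the `with_offset` helper and the position value are illustrative, not a library API (the same pattern applies to a `score_mod`):

```python
import torch

# Keep the offset on the GPU so torch.compile treats it as data, not a
# baked-in constant; updating it then does not trigger recompilation.
offset = torch.tensor(0, device="cuda")

def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

def with_offset(mask_mod):
    # Shift q_idx by the query's absolute position in the sequence.
    def wrapped(b, h, q_idx, kv_idx):
        return mask_mod(b, h, q_idx + offset, kv_idx)
    return wrapped

decode_mask = with_offset(causal_mask)

# Inside the decode loop: update in place with .fill_(), never rebind to an int.
offset.fill_(2048)  # e.g. we are now decoding the token at absolute position 2048
```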
The critical detail: offset must be a CUDA tensor updated with .fill_(). If you pass a Python integer, torch.compile treats it as a constant and recompiles the kernel every single step. That turns a 2ms decode step into a 30-second recompilation.
Troubleshooting
“BackendCompilerFailed: NoValidChoicesError” with BlockMask
This happens when you set a non-default BLOCK_SIZE in create_block_mask. Stick with the default 128:
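A sketch of the working call, with the failing variant shown commented out; the causal mask and sequence lengths are illustrative:

```python
from torch.nn.attention.flex_attention import create_block_mask

def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

# Can fail with NoValidChoicesError on some setups:
# create_block_mask(causal_mask, B=None, H=None, Q_LEN=8192, KV_LEN=8192,
#                   device="cuda", BLOCK_SIZE=256)

# Works: leave BLOCK_SIZE at its default of 128.
block_mask = create_block_mask(causal_mask, B=None, H=None,
                               Q_LEN=8192, KV_LEN=8192, device="cuda")
```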
Precision Drift with torch.set_float32_matmul_precision('high')
When you combine torch.compile + FlexAttention + torch.set_float32_matmul_precision('high'), outputs can diverge from SDPA by up to 11% relative difference. This is because the TF32 precision path interacts poorly with the fused kernel.
Fix: Use 'highest' precision, or just run in bfloat16 (which sidesteps the float32 matmul path entirely):
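Either of the following, where `q`, `k_cache`, and `v_cache` stand in for your existing float32 attention inputs:

```python
import torch

# Option 1: keep float32 but disable the TF32 fast path for matmuls.
torch.set_float32_matmul_precision("highest")

# Option 2: run the attention inputs (and KV cache) in bfloat16 instead,
# which avoids the float32 matmul path entirely.
q = q.to(torch.bfloat16)
k_cache = k_cache.to(torch.bfloat16)
v_cache = v_cache.to(torch.bfloat16)
```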
Segfault or “LLVM ERROR: Failed to compute parent layout” on Ampere GPUs
Mixed precision (torch.autocast) combined with compiled FlexAttention can crash on some Ampere GPUs (RTX 3090, RTX 3080). The error looks like:
```
LLVM ERROR: Failed to compute parent layout
```
Or simply a segmentation fault with no useful traceback. Workarounds:
- Cast your tensors to bfloat16 explicitly instead of using `torch.autocast`.
- Run without `torch.compile` (you lose the kernel fusion, but it won't crash).
- Update to PyTorch nightly – several of these Triton codegen bugs have been patched in recent builds.
No Backward Support on CPU
FlexAttention only supports backward passes on CUDA. Attempting to compute gradients on CPU raises a runtime error at the backward call.
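A minimal sketch of the failure mode (shapes are illustrative, and the exact error text varies across PyTorch versions):

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

# CPU tensors with gradients enabled
q = torch.randn(1, 8, 128, 64, requires_grad=True)
k = torch.randn(1, 8, 128, 64, requires_grad=True)
v = torch.randn(1, 8, 128, 64, requires_grad=True)

out = flex_attention(q, k, v)  # forward works on CPU via the eager fallback
out.sum().backward()           # fails: the backward kernel is CUDA-only
```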
This is by design – the fused Triton kernel is GPU-only. For CPU inference (forward pass only), the eager fallback works fine without compilation.
When to Use FlexAttention vs. FlashAttention
Use FlexAttention when:
- You need custom attention patterns (sliding window, document masking, ALiBi, soft-capping) and do not want to maintain custom CUDA kernels.
- You are already using `torch.compile` in your inference pipeline.
- You want GQA + PagedAttention without pulling in a separate library.
Stick with FlashAttention / SDPA when:
- You only need standard causal or full attention with no modifications. SDPA already dispatches to FlashAttention under the hood.
- You are on hardware without Triton support (older GPUs, non-NVIDIA).
- You need maximum absolute performance on Hopper – FlashAttention 3 is still about 25% faster than FlexAttention on H100.
Performance Numbers to Expect
These benchmarks come from the PyTorch team’s gpt-fast evaluations with LLaMA 3.1 models on NVIDIA GPUs:
| Model | Context Length | FlexDecoding vs SDPA |
|---|---|---|
| LLaMA 3.1-8B | 4k | 1.22x faster |
| LLaMA 3.1-8B | 16k | 2.04x faster |
| LLaMA 3.1-70B | 4k | 0.99x (parity) |
| LLaMA 3.1-70B | 16k | 1.66x faster |
The speedup scales with context length because FlexDecoding’s split-KV parallelization strategy is more effective when the KV cache is large. At short contexts, the overhead of kernel launch roughly matches the compute savings.
For training workloads with sample packing, torchtune reports a 71% throughput improvement using FlexAttention’s document masking compared to padding-based approaches.
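For reference, the document-masking pattern looks roughly like the sketch below; the `document_id` tensor and packing layout are illustrative, and torchtune's actual implementation may differ:

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask

# Packed batch of 8192 tokens; document_id[i] is the document index of token i.
document_id = torch.zeros(8192, dtype=torch.long, device="cuda")
document_id[4096:] = 1  # e.g. two packed documents

def document_causal_mask(b, h, q_idx, kv_idx):
    # Attend only within the same document, and only to earlier positions.
    same_doc = document_id[q_idx] == document_id[kv_idx]
    return same_doc & (q_idx >= kv_idx)

block_mask = create_block_mask(document_causal_mask, B=None, H=None,
                               Q_LEN=8192, KV_LEN=8192, device="cuda")
```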