PyTorch’s FlexAttention API lets you write custom attention patterns in plain Python and have torch.compile fuse them into a single optimized Triton kernel. No hand-written CUDA. No messing with FlashAttention forks. You define a score_mod or mask_mod function, and PyTorch generates a kernel that runs within 10% of FlashAttention 2 on Ampere GPUs.

Since PyTorch 2.5, FlexAttention has been available for training. The real news is the FlexDecoding backend – optimized specifically for inference workloads where you have a single query token attending to a long KV cache. In gpt-fast benchmarks, FlexDecoding delivers 1.22x to 2.04x speedup over SDPA on LLaMA 3.1-8B, and up to 1.66x on the 70B model at 16k context length.

Quick Start: Causal Decoding in 40 Lines

Here is the minimal setup for autoregressive decoding with FlexAttention. This works on PyTorch 2.5+ (the current stable release is 2.10).

import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

# Compile once -- reused across prefill and decode phases
flex_attention_compiled = torch.compile(flex_attention)

B, H, D = 4, 32, 64           # batch, heads, head_dim
MAX_SEQ_LEN = 16384

# Pre-allocate KV cache
k_cache = torch.zeros(B, H, MAX_SEQ_LEN, D, device="cuda", dtype=torch.bfloat16)
v_cache = torch.zeros(B, H, MAX_SEQ_LEN, D, device="cuda", dtype=torch.bfloat16)

# Offset tensor -- must be a CUDA tensor, not a Python int.
# Using a Python int causes recompilation every step.
offset = torch.tensor(0, device="cuda", dtype=torch.long)

def causal_mask(score, b, h, q_idx, kv_idx):
    return torch.where(q_idx + offset >= kv_idx, score, -float("inf"))

# Prefill: process the full prompt at once
prompt_len = 512
q_prefill = torch.randn(B, H, prompt_len, D, device="cuda", dtype=torch.bfloat16)
k_cache[:, :, :prompt_len, :] = torch.randn(B, H, prompt_len, D, device="cuda", dtype=torch.bfloat16)
v_cache[:, :, :prompt_len, :] = torch.randn(B, H, prompt_len, D, device="cuda", dtype=torch.bfloat16)

out = flex_attention_compiled(q_prefill, k_cache[:, :, :prompt_len], v_cache[:, :, :prompt_len], score_mod=causal_mask)

# Decode: one token at a time
for step in range(100):
    pos = prompt_len + step
    offset.fill_(pos)

    q_token = torch.randn(B, H, 1, D, device="cuda", dtype=torch.bfloat16)
    # ... update k_cache, v_cache at position `pos` ...

    out = flex_attention_compiled(
        q_token,
        k_cache[:, :, :pos + 1],
        v_cache[:, :, :pos + 1],
        score_mod=causal_mask,
    )

PyTorch automatically selects the FlexDecoding backend when it detects a short query (1-2 tokens) against a long KV sequence. No flag to flip – it just works.

The score_mod and mask_mod APIs

FlexAttention exposes two hooks into the attention computation:

  • score_mod(score, b, h, q_idx, kv_idx) -> Tensor – Modifies the raw dot-product score before softmax. The score argument is a scalar tensor. Return the modified score.
  • mask_mod(b, h, q_idx, kv_idx) -> bool – Returns True if position (q_idx, kv_idx) should attend. Positions returning False get masked out.

You can use either one, or both. mask_mod is more efficient when your pattern is purely boolean (causal, sliding window, block-sparse) because PyTorch can build a BlockMask and skip entire 128x128 blocks of computation.

from torch.nn.attention.flex_attention import create_block_mask

# Sliding window attention: each token attends to the last 1024 positions
WINDOW = 1024

def sliding_window(b, h, q_idx, kv_idx):
    return (q_idx - kv_idx >= 0) & (q_idx - kv_idx < WINDOW)

block_mask = create_block_mask(
    sliding_window,
    B=None,               # None = broadcast across batch
    H=None,               # None = broadcast across heads
    Q_LEN=MAX_SEQ_LEN,
    KV_LEN=MAX_SEQ_LEN,
    device="cuda",
)

# Use block_mask instead of score_mod
out = flex_attention_compiled(q, k, v, block_mask=block_mask)

The BlockMask is precomputed once and reused. During decoding, you slice it to match your current sequence position. The default block size is 128 – do not change it unless you have a specific reason, since non-128 block sizes can trigger compilation failures (see Troubleshooting below).
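
The two hooks also compose: you can pass a score_mod and a block_mask in the same call, keeping the boolean structure in the precomputed BlockMask and the continuous bias in score_mod. Here is a minimal sketch of that combination using an ALiBi-style bias on top of a causal BlockMask; the head count, sequence length, and slope formula below are illustrative assumptions following the standard ALiBi recipe, not values from the benchmarks above.

import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

B, H, SEQ, D = 2, 8, 4096, 64   # illustrative sizes

# Per-head ALiBi slopes: a geometric sequence, as in the ALiBi paper
alibi_slopes = torch.tensor([2 ** (-8 * (i + 1) / H) for i in range(H)], device="cuda")

def alibi(score, b, h, q_idx, kv_idx):
    # Add a linear distance penalty to the raw score before softmax
    return score + alibi_slopes[h] * (kv_idx - q_idx)

def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

# Precompute the boolean structure once; the bias stays in score_mod
block_mask = create_block_mask(causal, B=None, H=None, Q_LEN=SEQ, KV_LEN=SEQ, device="cuda")

q = torch.randn(B, H, SEQ, D, device="cuda", dtype=torch.bfloat16)
k = torch.randn(B, H, SEQ, D, device="cuda", dtype=torch.bfloat16)
v = torch.randn(B, H, SEQ, D, device="cuda", dtype=torch.bfloat16)

out = torch.compile(flex_attention)(q, k, v, score_mod=alibi, block_mask=block_mask)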

Grouped Query Attention (GQA)

Most production models (LLaMA 3, Mistral, Gemma 2) use GQA, where a small number of KV heads is shared across a larger number of query heads. FlexAttention handles this natively with enable_gqa=True:

Q_HEADS, KV_HEADS = 32, 8  # LLaMA 3 style
seq_len = 4096             # current length of the KV cache

q = torch.randn(B, Q_HEADS, 1, D, device="cuda", dtype=torch.bfloat16)
k = torch.randn(B, KV_HEADS, seq_len, D, device="cuda", dtype=torch.bfloat16)
v = torch.randn(B, KV_HEADS, seq_len, D, device="cuda", dtype=torch.bfloat16)

out = flex_attention_compiled(
    q, k, v,
    score_mod=causal_mask,
    enable_gqa=True,  # broadcasts KV heads to match Q heads
)

Without enable_gqa=True, you would need to manually repeat_interleave the KV tensors, which wastes memory. The flag tells FlexAttention to handle the broadcasting internally in the fused kernel.
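
For comparison, here is a sketch of that manual fallback, reusing the tensors from the snippet above. The expansion materializes Q_HEADS // KV_HEADS extra copies of every KV head before the kernel ever runs:

# Manual GQA broadcasting -- what enable_gqa=True lets you avoid
n_rep = Q_HEADS // KV_HEADS                      # 4 for the 32/8 split above
k_expanded = k.repeat_interleave(n_rep, dim=1)   # (B, 32, seq_len, D): 4x the KV memory
v_expanded = v.repeat_interleave(n_rep, dim=1)

out = flex_attention_compiled(q, k_expanded, v_expanded, score_mod=causal_mask)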

Offset Wrapping for Decoding Loops

A common pattern is wrapping an existing score_mod or mask_mod to add positional offset support. This keeps your attention logic clean and reusable:

def make_offset_score_mod(base_score_mod, offset_tensor):
    """Wrap a score_mod to shift q_idx by a runtime offset."""
    def _score_mod(score, b, h, q_idx, kv_idx):
        return base_score_mod(score, b, h, q_idx + offset_tensor, kv_idx)
    return _score_mod

def make_offset_mask_mod(base_mask_mod, offset_tensor):
    """Wrap a mask_mod to shift q_idx by a runtime offset."""
    def _mask_mod(b, h, q_idx, kv_idx):
        return base_mask_mod(b, h, q_idx + offset_tensor, kv_idx)
    return _mask_mod

# Usage
offset = torch.tensor(0, device="cuda", dtype=torch.long)

def causal(score, b, h, q_idx, kv_idx):
    return torch.where(q_idx >= kv_idx, score, -float("inf"))

causal_with_offset = make_offset_score_mod(causal, offset)
flex_attn = torch.compile(flex_attention)

for step in range(num_decode_steps):
    offset.fill_(prompt_len + step)   # absolute position of the current query token
    out = flex_attn(q_token, k_cache, v_cache, score_mod=causal_with_offset)

The critical detail: offset must be a CUDA tensor updated with .fill_(). If you pass a Python integer, torch.compile treats it as a constant and recompiles the kernel every single step. That turns a 2ms decode step into a 30-second recompilation.
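
If you want to confirm you are not hitting this, torch.compile can log every recompilation. A quick check, assuming PyTorch 2.5+ where the recompiles log artifact is available (the script name below is a placeholder):

import torch

# Log a message each time Dynamo recompiles. A healthy decode loop should emit
# a handful of these during warmup, not one per generated token.
torch._logging.set_logs(recompiles=True)

# Equivalent from the shell: TORCH_LOGS="recompiles" python your_decode_script.py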

Troubleshooting

“BackendCompilerFailed: NoValidChoicesError” with BlockMask

This happens when you set a non-default BLOCK_SIZE in create_block_mask. Stick with the default 128:

# This will crash:
block_mask = create_block_mask(mask_fn, B=1, H=1, Q_LEN=4096, KV_LEN=4096, BLOCK_SIZE=64)
# BackendCompilerFailed: NoValidChoicesError: No choices to select

# This works:
block_mask = create_block_mask(mask_fn, B=1, H=1, Q_LEN=4096, KV_LEN=4096)
# Uses default BLOCK_SIZE=128

Precision Drift with torch.set_float32_matmul_precision('high')

When you combine torch.compile + FlexAttention + torch.set_float32_matmul_precision('high'), outputs can diverge from SDPA by up to 11% relative difference. This is because the TF32 precision path interacts poorly with the fused kernel.

Fix: Use 'highest' precision, or just run in bfloat16 (which sidesteps the float32 matmul path entirely):

# Option A: force highest precision
torch.set_float32_matmul_precision('highest')

# Option B (recommended): use bfloat16 throughout
q = q.to(torch.bfloat16)
k = k.to(torch.bfloat16)
v = v.to(torch.bfloat16)
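
To quantify the drift on your own hardware, a parity check against SDPA is straightforward. A sketch under the same 'high' setting, with arbitrary shapes and plain causal masking so both paths compute the same thing:

import torch
import torch.nn.functional as F
from torch.nn.attention.flex_attention import flex_attention

torch.set_float32_matmul_precision("high")   # the setting that triggers the drift

B, H, S, D = 2, 8, 1024, 64
q = torch.randn(B, H, S, D, device="cuda")   # float32
k = torch.randn(B, H, S, D, device="cuda")
v = torch.randn(B, H, S, D, device="cuda")

def causal(score, b, h, q_idx, kv_idx):
    return torch.where(q_idx >= kv_idx, score, -float("inf"))

flex_out = torch.compile(flex_attention)(q, k, v, score_mod=causal)
sdpa_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Max relative difference between the two paths
rel = (flex_out - sdpa_out).abs().max() / sdpa_out.abs().max()
print(f"max relative difference: {rel.item():.4f}")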

Segfault or “LLVM ERROR: Failed to compute parent layout” on Ampere GPUs

Mixed precision (torch.autocast) combined with compiled FlexAttention can crash on some Ampere GPUs (RTX 3090, RTX 3080). The error looks like:

LLVM ERROR: Failed to compute parent layout for slice layout.
Aborted (core dumped)

Or simply a segmentation fault with no useful traceback. Workarounds:

  1. Cast your tensors to bfloat16 explicitly instead of using torch.autocast.
  2. Run without torch.compile (you lose the kernel fusion, but it won’t crash).
  3. Update to PyTorch nightly – several of these Triton codegen bugs have been patched in recent builds.

No Backward Support on CPU

FlexAttention only supports backward passes on CUDA. If you try to compute gradients on CPU, you get:

RuntimeError: FlexAttention does not support backward on CPU

This is by design – the fused Triton kernel is GPU-only. For CPU inference (forward pass only), the eager fallback works fine without compilation.
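
Per the note above, a forward-only CPU call skips compilation entirely; a minimal sketch with arbitrary shapes:

import torch
from torch.nn.attention.flex_attention import flex_attention

q = torch.randn(1, 4, 128, 64)   # CPU tensors, float32
k = torch.randn(1, 4, 128, 64)
v = torch.randn(1, 4, 128, 64)

def causal(score, b, h, q_idx, kv_idx):
    return torch.where(q_idx >= kv_idx, score, -float("inf"))

with torch.no_grad():            # forward only; backward requires CUDA
    out = flex_attention(q, k, v, score_mod=causal)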

When to Use FlexAttention vs. FlashAttention

Use FlexAttention when:

  • You need custom attention patterns (sliding window, document masking, ALiBi, soft-capping) and do not want to maintain custom CUDA kernels.
  • You are already using torch.compile in your inference pipeline.
  • You want GQA + PagedAttention without pulling in a separate library.

Stick with FlashAttention / SDPA when:

  • You only need standard causal or full attention with no modifications. SDPA already dispatches to FlashAttention under the hood.
  • You are on hardware without Triton support (older GPUs, non-NVIDIA).
  • You need maximum absolute performance on Hopper – FlashAttention 3 is still about 25% faster than FlexAttention on H100.

Performance Numbers to Expect

These benchmarks come from the PyTorch team’s gpt-fast evaluations with LLaMA 3.1 models on NVIDIA GPUs:

Model           Context Length    FlexDecoding vs SDPA
LLaMA 3.1-8B    4k                1.22x faster
LLaMA 3.1-8B    16k               2.04x faster
LLaMA 3.1-70B   4k                0.99x (parity)
LLaMA 3.1-70B   16k               1.66x faster

The speedup scales with context length because FlexDecoding’s split-KV parallelization strategy is more effective when the KV cache is large. At short contexts, kernel-launch overhead roughly cancels out the compute savings.

For training workloads with sample packing, torchtune reports a 71% throughput improvement using FlexAttention’s document masking compared to padding-based approaches.
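
That document-masking pattern is itself just another mask_mod over a per-token document id. A rough sketch of the idea (not torchtune's actual implementation; the document_id layout is an assumption, one id per token of the packed sequence, produced by whatever packing code you use):

import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

SEQ = 8192
# One document id per token, e.g. [0, 0, 0, 1, 1, 2, ...] -- filled in by your sample packer
document_id = torch.zeros(SEQ, dtype=torch.long, device="cuda")

def document_causal(b, h, q_idx, kv_idx):
    # Causal within a document, no attention across document boundaries
    return (document_id[q_idx] == document_id[kv_idx]) & (q_idx >= kv_idx)

block_mask = create_block_mask(document_causal, B=None, H=None, Q_LEN=SEQ, KV_LEN=SEQ, device="cuda")

q = torch.randn(1, 8, SEQ, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn(1, 8, SEQ, 64, device="cuda", dtype=torch.bfloat16)
v = torch.randn(1, 8, SEQ, 64, device="cuda", dtype=torch.bfloat16)

out = torch.compile(flex_attention)(q, k, v, block_mask=block_mask)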