Extract Attention Weights from Any Transformer

Every transformer model computes attention weights at each layer and head. These weights tell you how much each token attends to every other token during inference. You can pull them out with a single flag in Hugging Face Transformers.

from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

model_name = "textattack/bert-base-uncased-SST-2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(
    model_name, output_attentions=True
)

text = "The new policy raises serious concerns about user privacy."
inputs = tokenizer(text, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

# outputs.attentions is a tuple: one tensor per layer
# Each tensor shape: (batch, num_heads, seq_len, seq_len)
attentions = outputs.attentions

print(f"Layers: {len(attentions)}")
print(f"Shape per layer: {attentions[0].shape}")
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
print(f"Tokens: {tokens}")

# Grab attention from the last layer, first head
last_layer_head_0 = attentions[-1][0, 0].numpy()
print(f"\nAttention from last layer, head 0:")
for i, token in enumerate(tokens):
    top_3 = last_layer_head_0[i].argsort()[-3:][::-1]
    top_tokens = [(tokens[j], f"{last_layer_head_0[i][j]:.3f}") for j in top_3]
    print(f"  {token:>15s} attends to: {top_tokens}")

The key parameter is output_attentions=True. Without it, the model discards the attention weights after computing them and outputs.attentions is None. This works on BERT, GPT-2, RoBERTa, and virtually any Hugging Face transformer model.
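The tensors themselves come straight out of the softmax in scaled dot-product attention. A minimal sketch in plain PyTorch (random q/k tensors, with shapes chosen to mirror BERT-base) shows the shape convention and why every row of an attention matrix sums to 1:

```python
import torch

# Minimal scaled dot-product attention, to show where the
# (batch, num_heads, seq_len, seq_len) weight tensors come from
def attention_weights(q, k):
    d_k = q.shape[-1]
    scores = q @ k.transpose(-2, -1) / d_k ** 0.5  # (batch, heads, seq, seq)
    return torch.softmax(scores, dim=-1)

torch.manual_seed(0)
q = torch.randn(1, 12, 10, 64)  # (batch, heads, seq_len, head_dim), BERT-base-like
k = torch.randn(1, 12, 10, 64)

w = attention_weights(q, k)
print(w.shape)                                               # torch.Size([1, 12, 10, 10])
print(torch.allclose(w.sum(dim=-1), torch.ones(1, 12, 10)))  # True: softmax rows sum to 1
```

Because each row is a softmax output, a token's attention is a probability distribution over all tokens it can see, which is what makes the per-row rankings in the snippet above meaningful.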

Visualize Attention Heads with BertViz

Raw attention tensors are hard to interpret as numbers. BertViz turns them into interactive HTML visualizations where you can click on tokens and see attention flows across heads and layers. It is by far the best tool for this job.

pip install bertviz transformers torch
from bertviz import model_view, head_view
from transformers import AutoTokenizer, AutoModel

model_name = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name, output_attentions=True)

text = "The model flagged the transaction as potentially fraudulent."
inputs = tokenizer(text, return_tensors="pt")
outputs = model(**inputs)

tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
attention = outputs.attentions  # tuple of tensors

# Model view: all layers, all heads at once
model_view(attention, tokens)

# Head view: detailed look at a specific layer
head_view(attention, tokens, layer=11)

Run this in a Jupyter notebook. model_view renders a compact grid of every head across every layer. You can spot patterns immediately: some heads handle positional relationships, others focus on syntactic dependencies, and a few specialize in rare token patterns.

head_view zooms into a single layer and draws colored lines between tokens weighted by attention strength. This is what you show stakeholders. When someone asks “why did the model flag this transaction?” you point to the head that lights up between “flagged” and “fraudulent” with a thick attention line.

My recommendation: always start with model_view to get the bird’s-eye picture, then drill into specific layers with head_view. Layer 0 heads tend to capture local context (adjacent tokens). The final layers capture long-range semantic relationships, which is usually what you care about for explainability.
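One way to make the local-versus-long-range distinction concrete is mean attention distance: the average |i − j| between each query token and the keys it attends to, weighted by attention. Here is a sketch on two synthetic heads (the function and matrices are illustrative, not output from a real model):

```python
import numpy as np

# Mean attention distance: average token distance |i - j| weighted by the
# attention each query places on each key. Local heads score near 1,
# diffuse/long-range heads score much higher.
def mean_attention_distance(attn):
    seq_len = attn.shape[0]
    idx = np.arange(seq_len)
    dist = np.abs(idx[:, None] - idx[None, :])
    return float((attn * dist).sum() / seq_len)

seq_len = 6
previous_token = np.eye(seq_len, k=-1)  # each token attends to the one before it
previous_token[0, 0] = 1.0              # first token has no predecessor; attend to self
uniform = np.full((seq_len, seq_len), 1 / seq_len)  # attention spread over all tokens

print(mean_attention_distance(previous_token))  # ≈ 0.833
print(mean_attention_distance(uniform))         # ≈ 1.944
```

Running this metric over a real model's attentions typically shows distances growing with layer depth, which matches the local-to-semantic progression described above.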

Token Attribution with Captum and Integrated Gradients

Attention weights show what the model looked at, but they don’t directly tell you which tokens caused a specific prediction. For that, you need attribution methods. Captum’s Integrated Gradients is the gold standard because it satisfies the completeness axiom: the attributions sum to the difference between the model’s output on the input and its output on a baseline.

pip install captum transformers torch
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from captum.attr import LayerIntegratedGradients

model_name = "textattack/bert-base-uncased-SST-2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
model.eval()

text = "This product is dangerously misleading and harms consumers."
inputs = tokenizer(text, return_tensors="pt")
input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]

# Fix the target class up front so every interpolation step scores the same logit
with torch.no_grad():
    pred_class = model(input_ids, attention_mask=attention_mask).logits.argmax(dim=1).item()

# Define a forward function that takes input ids and returns the predicted class score
def forward_func(input_ids, attention_mask):
    outputs = model(input_ids, attention_mask=attention_mask)
    return outputs.logits[:, pred_class]

# Create the attribution object targeting the embedding layer; Captum
# interpolates between baseline and input embeddings internally
lig = LayerIntegratedGradients(forward_func, model.bert.embeddings)

# The baseline is a sequence of [PAD] tokens (token id 0 for BERT)
baseline_ids = torch.full_like(input_ids, tokenizer.pad_token_id)

attributions, delta = lig.attribute(
    inputs=input_ids,
    baselines=baseline_ids,
    additional_forward_args=(attention_mask,),
    n_steps=50,
    return_convergence_delta=True,
)

# Summarize attributions per token (sum across embedding dimensions)
attr_sum = attributions.sum(dim=-1).squeeze(0)
attr_sum = attr_sum / torch.norm(attr_sum)

tokens = tokenizer.convert_ids_to_tokens(input_ids[0])

print("Token attributions (normalized):")
for token, score in zip(tokens, attr_sum.tolist()):
    bar = ("+" if score > 0 else "-") * int(abs(score) * 40)
    print(f"  {token:>20s}  {score:+.4f}  {bar}")

print(f"\nConvergence delta: {delta.item():.6f}")

A convergence delta close to zero means the attributions are reliable. If it’s large (above 0.05), increase n_steps to 200 or 300. The computation is slower but the attributions become more trustworthy.

Integrated Gradients beats simpler methods like raw gradient saliency because it doesn’t suffer from gradient saturation. Vanilla gradients can give you near-zero attributions for highly important tokens if the model is confident. Integrated Gradients avoids this by integrating along a path from the baseline to the input.
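The completeness property is easy to see on a toy function. This hand-rolled sketch approximates the IG path integral with a Riemann sum (the three-weight "model" here is made up purely for illustration); the attribution total matches f(x) − f(baseline) up to discretization error:

```python
import torch

def f(x):
    # Toy differentiable "model": a squashed linear score (weights are made up)
    w = torch.tensor([1.5, -2.0, 0.5])
    return torch.sigmoid(x @ w)

def integrated_gradients(x, baseline, n_steps=200):
    # Riemann-sum approximation of the IG path integral from baseline to x
    total = torch.zeros_like(x)
    for alpha in torch.linspace(0, 1, n_steps):
        point = (baseline + alpha * (x - baseline)).requires_grad_(True)
        f(point).backward()
        total += point.grad
    return (x - baseline) * total / n_steps

x = torch.tensor([0.8, 0.3, -0.5])
baseline = torch.zeros(3)
attr = integrated_gradients(x, baseline)

# Completeness: attributions sum to f(x) - f(baseline), up to discretization error
print(attr.sum().item())
print((f(x) - f(baseline)).item())
```

Captum's convergence delta is exactly this gap between the attribution sum and the output difference, which is why raising n_steps shrinks it.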

Interpreting Multi-Head Attention Patterns

Not all attention heads do the same thing. Research from Clark et al. (2019) showed that BERT heads develop specialized roles. Understanding these patterns helps you trust (or distrust) model behavior.

Here’s how to profile what each head does:

import torch
import numpy as np
from transformers import AutoTokenizer, AutoModel

model_name = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name, output_attentions=True)

text = "The engineer who designed the bridge spoke at the conference."
inputs = tokenizer(text, return_tensors="pt")
outputs = model(**inputs)

tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
attentions = outputs.attentions

def classify_head(attn_matrix, tokens):
    """Classify an attention head's behavior pattern."""
    seq_len = attn_matrix.shape[0]
    patterns = {}

    # Check for diagonal pattern (attend to self)
    diag_score = np.trace(attn_matrix) / seq_len
    patterns["self_attention"] = diag_score

    # Check for previous-token pattern (attend to token at position i-1)
    prev_score = np.mean([attn_matrix[i, i - 1] for i in range(1, seq_len)])
    patterns["previous_token"] = prev_score

    # Check for [CLS] attention (everything attends to [CLS])
    cls_score = np.mean(attn_matrix[1:, 0])
    patterns["cls_focus"] = cls_score

    # Check for [SEP] attention
    sep_idx = tokens.index("[SEP]") if "[SEP]" in tokens else -1
    if sep_idx > 0:
        sep_score = np.mean(attn_matrix[:, sep_idx])
        patterns["sep_focus"] = sep_score

    dominant = max(patterns, key=patterns.get)
    return dominant, patterns

print("Head behavior classification:")
for layer_idx in [0, 5, 11]:  # Early, middle, late layers
    print(f"\n  Layer {layer_idx}:")
    for head_idx in range(12):
        attn = attentions[layer_idx][0, head_idx].detach().numpy()
        pattern, scores = classify_head(attn, tokens)
        top_score = scores[pattern]
        print(f"    Head {head_idx:2d}: {pattern:<20s} (score: {top_score:.3f})")

You’ll notice early layers (0-3) are dominated by positional patterns: self-attention and previous-token heads. Middle layers start showing syntactic patterns – heads that connect subjects to their verbs across relative clauses. Late layers (10-11) develop semantic heads that attend to task-relevant tokens.

This matters for explainability. When a stakeholder asks why the model made a decision, you should focus on late-layer heads. Early-layer positional heads are just preprocessing – they don’t carry interpretive value.
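Attention entropy is another cheap signal worth adding to a head profile: focused heads (self-attention, previous-token) have low entropy, while near-uniform "no-op" heads sit close to the log(seq_len) maximum. A sketch on two synthetic extremes (illustrative matrices, not real model output):

```python
import numpy as np

# Mean entropy (in nats) of each query's attention distribution.
# Low entropy = focused head; entropy near log(seq_len) = diffuse head.
def head_entropy(attn_matrix):
    p = np.clip(attn_matrix, 1e-12, 1.0)  # avoid log(0)
    return float(np.mean(-(p * np.log(p)).sum(axis=-1)))

seq_len = 8
focused = np.eye(seq_len)                           # each token attends only to itself
diffuse = np.full((seq_len, seq_len), 1 / seq_len)  # uniform over all tokens

print(head_entropy(focused))   # ≈ 0.0
print(head_entropy(diffuse))   # ≈ log(8) ≈ 2.079
```

Heads whose entropy is pinned near the maximum rarely carry interpretive value, so filtering them out first makes the pattern classification above less noisy.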

Building Explainability Reports for Stakeholders

Technical attention maps mean nothing to a product manager or compliance officer. You need to translate model internals into a format they can act on. Here’s a function that generates a plain-language explainability summary:

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

def explain_prediction(text, model_name="textattack/bert-base-uncased-SST-2", top_k=5):
    """Generate a stakeholder-friendly explanation for a model prediction."""
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name, output_attentions=True
    )
    model.eval()

    inputs = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    # Get prediction
    probs = torch.softmax(outputs.logits, dim=1)[0]
    pred_class = probs.argmax().item()
    confidence = probs[pred_class].item()
    label = model.config.id2label.get(pred_class, str(pred_class))

    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])

    # Average attention across all heads in the last layer
    last_layer_attn = outputs.attentions[-1][0].mean(dim=0)  # (seq_len, seq_len)
    # Sum how much attention each token receives from all other tokens
    token_importance = last_layer_attn.sum(dim=0).numpy()

    # Filter out special tokens
    token_scores = []
    for token, score in zip(tokens, token_importance):
        if token not in ("[CLS]", "[SEP]", "[PAD]"):
            token_scores.append((token, float(score)))

    token_scores.sort(key=lambda x: x[1], reverse=True)
    top_tokens = token_scores[:top_k]

    report = []
    report.append(f"Input: \"{text}\"")
    report.append(f"Prediction: {label} (confidence: {confidence:.1%})")
    report.append(f"Top {top_k} influential tokens:")
    for token, score in top_tokens:
        report.append(f"  - \"{token}\" (importance: {score:.3f})")

    return "\n".join(report)


# Generate reports for review
texts = [
    "The loan application was denied due to insufficient credit history.",
    "This candidate has excellent qualifications and strong references.",
    "The automated system flagged the account for suspicious activity.",
]

for text in texts:
    print(explain_prediction(text))
    print("-" * 60)

This approach averages attention across heads in the final layer and ranks tokens by how much total attention they receive. It is a simplification – attention is not the same as causation – but it gives stakeholders something concrete to review. When the top tokens align with reasonable decision factors, that builds trust. When they don’t (e.g., the model focuses on a name or demographic term), that’s a red flag worth investigating.

My strong opinion: never present attention visualizations alone as “proof” that a model is fair or correct. Attention shows correlation, not causation. Always pair attention analysis with attribution methods like Integrated Gradients. If both the attention map and the attribution scores point to the same tokens, you have a stronger case. If they disagree, dig deeper.
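When you do run both, it helps to quantify the agreement instead of eyeballing it. Here is a hedged sketch using Spearman rank correlation over per-token scores (the score lists are hypothetical, and the helper assumes no tied scores):

```python
import numpy as np

# Quantify agreement between attention-based and attribution-based token
# rankings with Spearman rank correlation (no tie handling; assumes unique scores)
def spearman(a, b):
    ra = np.argsort(np.argsort(a)).astype(float)  # rank of each token by score a
    rb = np.argsort(np.argsort(b)).astype(float)  # rank of each token by score b
    ra -= ra.mean()
    rb -= rb.mean()
    return float((ra * rb).sum() / np.sqrt((ra ** 2).sum() * (rb ** 2).sum()))

attention_scores   = [0.9, 0.10, 0.6, 0.3]  # hypothetical attention received per token
attribution_scores = [0.8, 0.05, 0.5, 0.2]  # hypothetical IG attribution per token

print(spearman(attention_scores, attribution_scores))  # 1.0: rankings agree exactly
```

A correlation near 1 supports the explanation; a low or negative value is exactly the "dig deeper" signal described above.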

Common Errors

RuntimeError: output_attentions is not supported for this model – Some model architectures don’t return attention weights. Check if your model actually supports it:

# Verify the model config accepts attention outputs
from transformers import AutoConfig
config = AutoConfig.from_pretrained("your-model-name")
print(config.output_attentions)  # False by default; enable it as in the examples above

If the model doesn’t support attention outputs natively, switch to Captum’s attribution methods instead. They work on any differentiable model.

ImportError: No module named 'bertviz' – BertViz is not bundled with Transformers. Install it separately:

pip install bertviz

Also make sure you’re running in a Jupyter environment. BertViz renders HTML widgets that won’t display in a plain Python script.

ValueError: too many dimensions 'str' when calling model_view – This happens when you pass raw strings instead of tokenized outputs. BertViz expects a tuple of attention tensors and a list of token strings:

# Wrong: passing the raw text
model_view(attention, text)

# Right: passing tokenized token list
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
model_view(attention, tokens)

CUDA out of memory with Integrated Gradients – IG computes gradients across n_steps interpolation points, which multiplies memory usage. Reduce n_steps or move to CPU for smaller models:

# Option 1: reduce steps (less precise but fits in memory)
attributions, delta = lig.attribute(
    inputs=input_ids,
    baselines=baseline_ids,
    additional_forward_args=(attention_mask,),
    n_steps=20,  # down from 50
    return_convergence_delta=True,
)

# Option 2: run on CPU
model = model.cpu()
input_ids = input_ids.cpu()
baseline_ids = baseline_ids.cpu()
attention_mask = attention_mask.cpu()

AttributeError: 'GPT2Model' object has no attribute 'embeddings' – Different model architectures name their embedding layers differently. For GPT-2, use model.transformer.wte. For RoBERTa, use model.roberta.embeddings:

# BERT
lig = LayerIntegratedGradients(forward_func, model.bert.embeddings)

# GPT-2
lig = LayerIntegratedGradients(forward_func, model.transformer.wte)

# RoBERTa
lig = LayerIntegratedGradients(forward_func, model.roberta.embeddings)

Check your model’s architecture with print(model) to find the right embedding layer name.
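If you would rather not memorize per-architecture paths at all, every Hugging Face model exposes get_input_embeddings(). Note that it returns the token-embedding table itself (wte for GPT-2, bert.embeddings.word_embeddings for BERT), not the full embeddings module used earlier, which also adds position embeddings; both are common attribution targets in practice, though the resulting scores differ slightly:

```python
from transformers import AutoModel

# Architecture-agnostic lookup: get_input_embeddings() returns the token
# embedding table for any Hugging Face model, regardless of attribute naming
model = AutoModel.from_pretrained("gpt2")
embedding_layer = model.get_input_embeddings()

print(type(embedding_layer).__name__)                              # Embedding
print(embedding_layer.weight.shape[0] == model.config.vocab_size)  # True
```

With this, LayerIntegratedGradients(forward_func, model.get_input_embeddings()) works unchanged across BERT, GPT-2, and RoBERTa.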