Prefix tuning prepends a sequence of trainable continuous vectors – called “virtual tokens” – to the input at every transformer layer. The base model weights stay frozen. Only the prefix parameters update during training, typically well under 1% of the total parameters. The result: 90%+ memory savings compared to full fine-tuning, and you can store dozens of task-specific adapters as tiny checkpoint files alongside a single base model.

Unlike LoRA, which modifies weight matrices with low-rank decompositions, prefix tuning operates in the activation space. It learns task-specific context that steers the model’s attention patterns without touching any existing parameters. This makes it particularly effective for generation tasks where you want to condition the model on a specific style or domain.
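
To make the mechanism concrete, here is a minimal single-head attention sketch – toy shapes and identity projections, not PEFT internals – showing where the trainable prefix enters:

```python
import torch
import torch.nn.functional as F

# Toy single-head attention with a trainable prefix (illustrative shapes only;
# real prefix tuning injects the prefix into every layer's key/value cache).
hidden, num_virtual_tokens, seq_len = 8, 4, 5

x = torch.randn(seq_len, hidden)  # activations from the frozen model
q, k, v = x, x, x                 # identity projections, for brevity

# The only trainable parameters: one key and one value per virtual token.
prefix_k = torch.nn.Parameter(torch.randn(num_virtual_tokens, hidden))
prefix_v = torch.nn.Parameter(torch.randn(num_virtual_tokens, hidden))

k = torch.cat([prefix_k, k], dim=0)  # (num_virtual_tokens + seq_len, hidden)
v = torch.cat([prefix_v, v], dim=0)

scores = q @ k.T / hidden ** 0.5      # queries attend over prefix + real input
out = F.softmax(scores, dim=-1) @ v   # (seq_len, hidden), same shape as before
print(out.shape)  # torch.Size([5, 8])
```

During training, gradients flow only into prefix_k and prefix_v; the frozen model’s activations simply pass through.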

Setting Up PEFT for Prefix Tuning

Install the required packages:

pip install peft transformers datasets torch accelerate

Now configure PrefixTuningConfig and wrap a model. We’ll use GPT-2 here because it runs on any GPU, but this works the same way with Llama, Mistral, or any causal LM in the Hugging Face ecosystem.

from peft import PrefixTuningConfig, get_peft_model, TaskType
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(model_name)

prefix_config = PrefixTuningConfig(
    task_type=TaskType.CAUSAL_LM,
    num_virtual_tokens=20,
    encoder_hidden_size=768,  # Hidden size of the projection MLP (GPT-2's hidden size)
    prefix_projection=True,   # Use a 2-layer MLP to generate the prefix embeddings
)

model = get_peft_model(model, prefix_config)
model.print_trainable_parameters()
# Output: trainable params: 983,040 || all params: 125,424,384 || trainable%: 0.7837

Key parameters in PrefixTuningConfig:

  • num_virtual_tokens: The number of prefix tokens prepended at each layer. 20 is a solid starting point. More tokens give the model more capacity to steer behavior, but beyond 50 you hit diminishing returns and slower inference.
  • encoder_hidden_size: The hidden size of the projection MLP. The usual choice is the model’s hidden dimension: 768 for GPT-2, 4096 for Llama-2-7B.
  • prefix_projection: When True, PEFT uses a 2-layer MLP to generate the prefix embeddings instead of optimizing them directly. This usually stabilizes training noticeably, so leave it on unless you have a specific reason not to.
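
As a back-of-the-envelope check on these settings, the raw prefix state (without the projection MLP) is easy to count by hand – GPT-2 small figures assumed:

```python
# Without prefix_projection, the trainable state is one key and one value
# vector per virtual token per layer (GPT-2 small: 12 layers, hidden size 768).
num_layers, hidden_size, num_virtual_tokens = 12, 768, 20
prefix_params = num_layers * 2 * num_virtual_tokens * hidden_size
print(prefix_params)  # 368640
```

With prefix_projection=True the MLP adds trainable parameters on top of this during training, but the saved adapter stays small because PEFT stores the projected prefix tensors rather than the MLP itself.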

Training a Prefix-Tuned Model

We’ll fine-tune on a subset of the dair-ai/emotion dataset for text generation conditioned on emotion labels. This is a real dataset with 6 emotion categories.

import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import default_data_collator, get_linear_schedule_with_warmup

# Load and preprocess dataset
dataset = load_dataset("dair-ai/emotion", split="train[:2000]")
label_map = {0: "sadness", 1: "joy", 2: "love", 3: "anger", 4: "fear", 5: "surprise"}

def preprocess(example):
    label_text = label_map[example["label"]]
    prompt = f"Emotion: {label_text}\nText: {example['text']}{tokenizer.eos_token}"
    tokenized = tokenizer(
        prompt,
        truncation=True,
        max_length=128,
        padding="max_length",
    )
    # Mask padding so it doesn't contribute to the loss (pad_token == eos_token here)
    tokenized["labels"] = [
        tok if mask == 1 else -100
        for tok, mask in zip(tokenized["input_ids"], tokenized["attention_mask"])
    ]
    return tokenized

tokenized_dataset = dataset.map(preprocess, remove_columns=dataset.column_names)
tokenized_dataset.set_format("torch")

train_dataloader = DataLoader(
    tokenized_dataset,
    shuffle=True,
    batch_size=8,
    collate_fn=default_data_collator,
)

# Training setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
num_epochs = 3
num_training_steps = num_epochs * len(train_dataloader)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=100,
    num_training_steps=num_training_steps,
)

# Training loop
model.train()
for epoch in range(num_epochs):
    total_loss = 0
    for step, batch in enumerate(train_dataloader):
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs.loss
        loss.backward()

        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

        total_loss += loss.detach().float().item()

        if step % 50 == 0:
            avg_loss = total_loss / (step + 1)
            print(f"Epoch {epoch} | Step {step} | Loss: {avg_loss:.4f}")

    avg_epoch_loss = total_loss / len(train_dataloader)
    print(f"Epoch {epoch} complete | Avg Loss: {avg_epoch_loss:.4f}")

# Save the prefix adapter
model.save_pretrained("prefix-emotion-adapter")
tokenizer.save_pretrained("prefix-emotion-adapter")

The saved adapter is tiny – typically under 5MB. The entire base model stays untouched.

A few training notes worth calling out. The learning rate of 3e-4 is higher than what you’d use for LoRA (2e-4). Prefix tuning benefits from slightly more aggressive learning rates because you’re optimizing far fewer parameters. If you see the loss spike or oscillate, drop it to 1e-4.
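
If you want to see where the schedule puts the learning rate at a given step, the linear warmup-and-decay used above is easy to compute by hand. This sketch mirrors the formula behind get_linear_schedule_with_warmup; the step counts assume the setup above (2,000 examples, batch size 8, 3 epochs):

```python
base_lr, warmup_steps, total_steps = 3e-4, 100, 3 * 250  # 250 steps per epoch

def lr_at(step):
    # Linear ramp up to base_lr, then linear decay to zero.
    if step < warmup_steps:
        return base_lr * step / warmup_steps
    return base_lr * max(0.0, (total_steps - step) / (total_steps - warmup_steps))

print(lr_at(50), lr_at(100), lr_at(750))
```

Halfway through warmup you are at 1.5e-4, the peak 3e-4 lands exactly at step 100, and the rate decays to zero by the final step.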

Running Inference with a Prefix-Tuned Model

Load the saved adapter and generate text:

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

base_model = AutoModelForCausalLM.from_pretrained("gpt2")
model = PeftModel.from_pretrained(base_model, "prefix-emotion-adapter")
tokenizer = AutoTokenizer.from_pretrained("prefix-emotion-adapter")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

prompt = "Emotion: joy\nText:"
inputs = tokenizer(prompt, return_tensors="pt").to(device)

with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=50,
        temperature=0.7,
        do_sample=True,
        top_p=0.9,
    )

generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(generated_text)

You can swap adapters at runtime without reloading the base model. This is one of the big practical wins of parameter-efficient methods – serve one base model, keep a library of adapters for different tasks, and load the right one per request.

# Load a different adapter on the same base model
model.load_adapter("prefix-formality-adapter", adapter_name="formality")
model.set_adapter("formality")

Comparing Prefix Tuning to LoRA

Both are PEFT methods, but they work differently and have distinct sweet spots.

Prefix tuning prepends virtual tokens to the key-value pairs in attention. It’s best when you want to steer generation style or condition on task context without altering the model’s internal representations. It typically trains fewer parameters than a comparable LoRA setup and produces small adapter files. The downside: it adds latency proportional to num_virtual_tokens, because every layer attends over those extra positions during inference.
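
To put a rough number on that latency cost – toy arithmetic, not a benchmark – with a 128-token context and 20 virtual tokens, each attention layer scores about 16% more key positions:

```python
# Attention score computations per layer, with and without the prefix.
seq_len, num_virtual_tokens = 128, 20
baseline = seq_len * seq_len                            # no prefix
with_prefix = seq_len * (seq_len + num_virtual_tokens)  # prefix prepended
print(f"{with_prefix / baseline:.3f}x")  # 1.156x
```

The overhead shrinks as the real context grows, so prefix length matters most for short prompts.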

LoRA injects trainable low-rank matrices into the attention and MLP weight matrices. It modifies how the model computes representations at each layer. LoRA is more expressive for the same parameter count and generally achieves better results on complex tasks like code generation or reasoning. It also adds zero inference latency because you can merge the adapter weights into the base model.

When to use which:

  • Pick prefix tuning when you need minimal adapter size, have limited GPU memory during training, or want a quick task-specific conditioner for generation.
  • Pick LoRA when you need the best possible quality, are fine-tuning for complex tasks, or want zero-overhead inference after merging.
  • Both beat full fine-tuning on cost and flexibility. You can always start with prefix tuning as a fast experiment and switch to LoRA if you need more capacity.

Common Errors and Fixes

ValueError: encoder_hidden_size must match the model's hidden size

You set encoder_hidden_size to the wrong value. Check the model config:

from transformers import AutoConfig

config = AutoConfig.from_pretrained("gpt2")
print(config.hidden_size)  # 768 for gpt2

Use whatever this prints as your encoder_hidden_size.

RuntimeError: CUDA out of memory

Even though prefix tuning is memory-efficient, you can still OOM on large models. Reduce batch size first, then reduce num_virtual_tokens:

# Drop batch size
train_dataloader = DataLoader(tokenized_dataset, batch_size=2, ...)

# Or reduce virtual tokens from 20 to 10
prefix_config = PrefixTuningConfig(
    task_type=TaskType.CAUSAL_LM,
    num_virtual_tokens=10,
    encoder_hidden_size=768,
    prefix_projection=True,
)

RuntimeError: expected scalar type Half but found Float

Mixed precision mismatch. Either cast the model explicitly or use torch.autocast:

import torch

with torch.autocast("cuda", dtype=torch.float16):
    outputs = model(**batch)
    loss = outputs.loss

KeyError: 'past_key_values' during generation

Some older versions of PEFT have bugs with prefix tuning and generate(). Update PEFT:

pip install --upgrade "peft>=0.13.0"

Training loss doesn’t decrease

Check that prefix_projection=True is set. Without the projection MLP, the prefix embeddings are harder to optimize and training can stall, especially on larger models. Also verify your learning rate isn’t too low – prefix tuning can tolerate 3e-4 to 1e-3 depending on the model size.
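
A quick way to rule out the “nothing is trainable” failure mode is to count parameters with requires_grad set. The pattern is shown here on a toy module rather than the PEFT model, but the same parameters() check applies to the wrapped model:

```python
import torch.nn as nn

# Stand-in for a wrapped model where freezing went wrong.
model = nn.Linear(4, 2)
for p in model.parameters():
    p.requires_grad = False  # simulates accidentally freezing everything

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(trainable)  # 0 -> the loss has nothing to move
```

On a correctly configured prefix-tuned model, this count should match what print_trainable_parameters reported after get_peft_model.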