## The Problem: You Can’t Just Delete the Row
Someone submits a GDPR data removal request. You delete their records from the database. Done, right? Not quite. If their data was used to train a model, that model still carries traces of it. Weight updates during training baked their data into the parameters. Regulators increasingly expect you to address this, and “we’d have to retrain from scratch” isn’t a great answer when retraining costs $50k and takes a week.
Machine unlearning gives you a practical middle ground. You modify the trained model so it behaves as if specific data points were never in the training set. The standard recipe has two stages: gradient ascent on the forget set (push the model away from that data), then fine-tuning on the retain set (recover performance on everything else).
Install what you need:
```bash
pip install torch torchvision scikit-learn numpy
```
Here’s a quick proof of concept. We’ll train a model, then unlearn a subset of the training data:
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import numpy as np

# Simple classifier
class SmallNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(28 * 28, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 10)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        return self.fc3(x)

# Load MNIST
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)
full_train = datasets.MNIST("./data", train=True, download=True, transform=transform)
test_set = datasets.MNIST("./data", train=False, download=True, transform=transform)

# Split: 500 samples to forget, rest to retain
forget_indices = list(range(500))
retain_indices = list(range(500, len(full_train)))
forget_set = Subset(full_train, forget_indices)
retain_set = Subset(full_train, retain_indices)

forget_loader = DataLoader(forget_set, batch_size=64, shuffle=True)
retain_loader = DataLoader(retain_set, batch_size=64, shuffle=True)
test_loader = DataLoader(test_set, batch_size=256, shuffle=False)

# Train the original model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SmallNet().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

for epoch in range(5):
    model.train()
    for images, labels in DataLoader(full_train, batch_size=64, shuffle=True):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()

print("Original model trained.")
```
That gives you a baseline. The model has seen every sample, including the 500 we want it to forget.
## Gradient Ascent Unlearning
The idea is simple and a bit counterintuitive: run training in reverse on the forget set. Instead of minimizing the loss, you maximize it. This pushes the model’s parameters away from configurations that fit the forget data well.
You do this by negating the loss. Standard training computes loss.backward() and steps toward lower loss. Here, you compute (-loss).backward() so the gradient points uphill.
```python
def gradient_ascent_unlearn(model, forget_loader, epochs=5, lr=1e-4):
    """Push the model away from the forget set by maximizing loss."""
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    model.train()
    for epoch in range(epochs):
        total_loss = 0.0
        for images, labels in forget_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            # Negate the loss -- gradient ascent
            (-loss).backward()
            optimizer.step()
            total_loss += loss.item()
        avg_loss = total_loss / len(forget_loader)
        print(f"Unlearn epoch {epoch+1}: avg loss on forget set = {avg_loss:.4f}")
    return model

# Run gradient ascent unlearning
model = gradient_ascent_unlearn(model, forget_loader, epochs=5, lr=1e-4)
```
A few things matter here. The learning rate should be small – typically 5x to 10x lower than what you used for original training. Too aggressive and you’ll destroy the model entirely. Five epochs is usually enough for a small forget set. Watch the loss on the forget set: it should climb steadily.
## Fine-Tuning on the Retain Set
Gradient ascent is destructive by design. It doesn’t just forget the target data; it damages nearby decision boundaries too. Your test accuracy will drop. The fix: fine-tune on the retain set to recover what you broke.
```python
def finetune_retain(model, retain_loader, epochs=3, lr=5e-5):
    """Restore model performance on data that should be remembered."""
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    model.train()
    for epoch in range(epochs):
        total_loss = 0.0
        correct = 0
        total = 0
        for images, labels in retain_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)
        acc = 100.0 * correct / total
        print(f"Retain epoch {epoch+1}: loss={total_loss/len(retain_loader):.4f}, acc={acc:.1f}%")
    return model

model = finetune_retain(model, retain_loader, epochs=3, lr=5e-5)
```
Use an even lower learning rate here. You want gentle correction, not another round of heavy training. Three to five epochs typically suffice. If accuracy on the retain set doesn’t come back to within 1-2% of the original, try lowering the gradient ascent learning rate or reducing the number of unlearning epochs.
## Verifying Unlearning with Membership Inference
You’ve run the pipeline, but how do you prove the model actually forgot? The standard approach: a membership inference test. If the model still recognizes the forget set as “seen before,” unlearning failed.
The core signal is confidence. A model is typically more confident on data it was trained on. After unlearning, the forget set should look no different from unseen test data.
```python
def membership_inference_audit(model, forget_loader, test_loader):
    """Check if the model still distinguishes forget data from unseen data."""
    model.eval()

    def get_confidence(loader, max_batches=20):
        confidences = []
        with torch.no_grad():
            for i, (images, _) in enumerate(loader):
                if i >= max_batches:
                    break
                images = images.to(device)
                probs = torch.softmax(model(images), dim=1)
                max_conf = probs.max(dim=1).values
                confidences.extend(max_conf.cpu().numpy())
        return np.array(confidences)

    forget_conf = get_confidence(forget_loader)
    test_conf = get_confidence(test_loader)
    print(f"Forget set mean confidence: {forget_conf.mean():.4f} (std: {forget_conf.std():.4f})")
    print(f"Test set mean confidence:   {test_conf.mean():.4f} (std: {test_conf.std():.4f})")
    gap = abs(forget_conf.mean() - test_conf.mean())
    print(f"Confidence gap: {gap:.4f}")
    if gap < 0.02:
        print("PASS: Model treats forget set similarly to unseen data.")
    else:
        print("WARN: Confidence gap suggests residual memorization.")
    return gap

gap = membership_inference_audit(model, forget_loader, test_loader)
```
A confidence gap under 0.02 is a good sign. The model can’t distinguish forgotten samples from data it never saw. If the gap is still large, run more gradient ascent epochs or increase the learning rate slightly.
For a more rigorous test, train a binary classifier on the confidence vectors to predict membership. If that classifier does no better than 50% accuracy on forget-set vs. test-set samples, unlearning is solid.
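A minimal sketch of that stronger attack, using the scikit-learn installed earlier: fit a logistic regression on per-sample max confidences labeled member vs. non-member and check its held-out accuracy against chance. The `attack_accuracy` helper is ours, and the synthetic arrays below stand in for the `forget_conf` and `test_conf` vectors computed inside `membership_inference_audit`:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

def attack_accuracy(forget_conf, test_conf, seed=0):
    """Train an attacker to tell forget-set confidences from test-set ones.
    Held-out accuracy near 0.5 means the attacker can't distinguish them."""
    X = np.concatenate([forget_conf, test_conf]).reshape(-1, 1)
    y = np.concatenate([np.ones(len(forget_conf)), np.zeros(len(test_conf))])
    X_tr, X_te, y_tr, y_te = train_test_split(
        X, y, test_size=0.3, stratify=y, random_state=seed
    )
    clf = LogisticRegression().fit(X_tr, y_tr)
    return clf.score(X_te, y_te)

# Synthetic demo: identical confidence distributions should land near chance
rng = np.random.default_rng(0)
forget_conf = rng.uniform(0.5, 1.0, 500)
test_conf = rng.uniform(0.5, 1.0, 500)
print(f"Attack accuracy: {attack_accuracy(forget_conf, test_conf):.2f}")
```

If the score on your real confidence vectors sits well above 0.5, the model is still leaking membership and needs another unlearning pass.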
## End-to-End Unlearning Pipeline
Here’s the complete pipeline wrapped in a reusable class. This is what you’d integrate into your data removal workflow:
```python
import json
from datetime import datetime

class UnlearningPipeline:
    def __init__(self, model, device, forget_lr=1e-4, retain_lr=5e-5):
        self.model = model
        self.device = device
        self.forget_lr = forget_lr
        self.retain_lr = retain_lr
        self.criterion = nn.CrossEntropyLoss()
        self.audit_log = []

    def _evaluate(self, loader):
        self.model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for images, labels in loader:
                images, labels = images.to(self.device), labels.to(self.device)
                preds = self.model(images).argmax(dim=1)
                correct += (preds == labels).sum().item()
                total += labels.size(0)
        return 100.0 * correct / total

    def unlearn(self, forget_loader, retain_loader, test_loader,
                unlearn_epochs=5, retain_epochs=3, request_id="unknown"):
        """Full unlearning cycle: gradient ascent, fine-tune, verify."""
        # Snapshot pre-unlearning metrics
        pre_test_acc = self._evaluate(test_loader)
        pre_forget_acc = self._evaluate(forget_loader)
        print(f"[{request_id}] Pre-unlearning -- test acc: {pre_test_acc:.1f}%, "
              f"forget acc: {pre_forget_acc:.1f}%")

        # Step 1: Gradient ascent on forget set
        optimizer = optim.Adam(self.model.parameters(), lr=self.forget_lr)
        self.model.train()
        for epoch in range(unlearn_epochs):
            for images, labels in forget_loader:
                images, labels = images.to(self.device), labels.to(self.device)
                optimizer.zero_grad()
                loss = self.criterion(self.model(images), labels)
                (-loss).backward()
                optimizer.step()

        # Step 2: Fine-tune on retain set
        optimizer = optim.Adam(self.model.parameters(), lr=self.retain_lr)
        for epoch in range(retain_epochs):
            self.model.train()
            for images, labels in retain_loader:
                images, labels = images.to(self.device), labels.to(self.device)
                optimizer.zero_grad()
                loss = self.criterion(self.model(images), labels)
                loss.backward()
                optimizer.step()

        # Step 3: Verify
        post_test_acc = self._evaluate(test_loader)
        post_forget_acc = self._evaluate(forget_loader)
        print(f"[{request_id}] Post-unlearning -- test acc: {post_test_acc:.1f}%, "
              f"forget acc: {post_forget_acc:.1f}%")

        record = {
            "request_id": request_id,
            "timestamp": datetime.now().isoformat(),
            "pre_test_acc": pre_test_acc,
            "post_test_acc": post_test_acc,
            "pre_forget_acc": pre_forget_acc,
            "post_forget_acc": post_forget_acc,
            "accuracy_drop": pre_test_acc - post_test_acc,
            "forget_samples": len(forget_loader.dataset),
        }
        self.audit_log.append(record)
        return record

    def save_audit_log(self, path="unlearning_audit.json"):
        with open(path, "w") as f:
            json.dump(self.audit_log, f, indent=2)
        print(f"Audit log saved to {path}")

# Usage
pipeline = UnlearningPipeline(model, device)
result = pipeline.unlearn(
    forget_loader, retain_loader, test_loader,
    unlearn_epochs=5, retain_epochs=3,
    request_id="GDPR-2026-0042",
)
pipeline.save_audit_log()
```
The audit log is important. When a regulator asks “how did you handle this deletion request?”, you can point to a timestamped record showing accuracy on the forget set dropped, test accuracy stayed reasonable, and the membership inference gap closed. Keep these logs alongside your data processing records.
## Tuning the Unlearning Hyperparameters
Getting the balance right between forgetting and retaining is the hard part. Here’s what works in practice:
- Forget learning rate: Start at 10% of your original training LR. If the model doesn’t forget enough, bump it up. If test accuracy craters, pull it back.
- Forget epochs: 3-7 for small forget sets (under 1% of training data). More epochs for larger forget sets.
- Retain learning rate: Half the forget LR or less. You want minimal parameter movement here.
- Retain epochs: 2-5 usually does it. Stop when test accuracy plateaus.
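These rules of thumb can be collected into defaults. A small helper sketch (the function name and exact thresholds are illustrative, not from any library; it just restates the heuristics above):

```python
def suggest_unlearning_params(train_lr, forget_fraction):
    """Heuristic starting points for unlearning, per the rules of thumb above.

    train_lr        -- learning rate used for original training
    forget_fraction -- forget-set size as a fraction of the training set
    """
    forget_lr = train_lr * 0.1    # start at 10% of the training LR
    retain_lr = forget_lr * 0.5   # half the forget LR or less
    # 3-7 forget epochs for small forget sets (<1% of training data), more beyond
    forget_epochs = 5 if forget_fraction < 0.01 else 7
    retain_epochs = 3             # 2-5 usually does it
    return {
        "forget_lr": forget_lr,
        "retain_lr": retain_lr,
        "forget_epochs": forget_epochs,
        "retain_epochs": retain_epochs,
    }

# For the MNIST example above: 1e-3 training LR, 500 of 60,000 samples forgotten
params = suggest_unlearning_params(train_lr=1e-3, forget_fraction=500 / 60000)
print(params)
```

For the MNIST setup this reproduces the values used throughout: `forget_lr=1e-4`, `retain_lr=5e-5`. Treat the output as a starting point for the tuning loop, not a final answer.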
For models trained on sensitive data where the stakes are high, run the full membership inference audit after every unlearning request. For lower-risk models, a quick confidence gap check is enough.
## Common Errors and Fixes
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
This happens when your model uses inplace ReLU (nn.ReLU(inplace=True)). The negated loss backward pass doesn’t play well with inplace ops. Fix: switch to nn.ReLU() without the inplace flag.
Test accuracy drops to near-random after gradient ascent
Your unlearning learning rate is too high or you ran too many epochs. The gradient ascent step scrambled the weights beyond recovery. Reduce forget_lr by half and cut epochs to 2-3. Also check that your forget set isn’t disproportionately large – unlearning 30% of the training set will always be destructive.
Membership inference gap stays high after unlearning
The model still memorizes the forget data. Try: (1) increase gradient ascent epochs by 2-3, (2) slightly raise the forget LR, (3) add noise to the model weights after unlearning with param.data += torch.randn_like(param) * 1e-4. Sometimes the loss landscape has sharp minima that a few gradient steps won’t escape.
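Option (3) is one pass over the parameters. A minimal sketch; the helper name `perturb_weights` is ours, not a torch API, and `sigma=1e-4` matches the value above:

```python
import torch

def perturb_weights(model, sigma=1e-4):
    """Add small Gaussian noise to every parameter, in place.
    Can help kick the model out of a sharp minimum before another ascent pass."""
    with torch.no_grad():
        for param in model.parameters():
            param.add_(torch.randn_like(param) * sigma)
```

Run it between the gradient ascent and retain fine-tuning steps, then re-check the membership inference gap.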
CUDA out of memory during gradient ascent
Gradient ascent uses the same memory as normal training. If your original training used gradient accumulation, use it here too. Reduce the batch size of forget_loader or use torch.cuda.amp.autocast() for mixed precision.
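A sketch of the mixed-precision variant of the ascent step, using a toy `nn.Linear` and a random batch as stand-ins for the real model and forget data. The `enabled` flags make autocast and loss scaling no-ops on CPU-only machines, so this only changes memory use when CUDA is available:

```python
import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_amp = torch.cuda.is_available()  # autocast/scaling only pay off on GPU

model = nn.Linear(10, 2).to(device)          # stand-in for the real model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

images = torch.randn(8, 10, device=device)   # stand-in forget batch
labels = torch.randint(0, 2, (8,), device=device)

optimizer.zero_grad()
with torch.autocast(device_type=device.type, enabled=use_amp):
    loss = criterion(model(images), labels)
scaler.scale(-loss).backward()  # same negated-loss ascent, scaled for fp16
scaler.step(optimizer)
scaler.update()
```

The only change from the plain loop is negating the loss before `scaler.scale(...)`; the scaler handles the rest.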
Forget set accuracy doesn’t drop at all
Double-check that you’re negating the loss, not the gradients. The correct pattern is (-loss).backward(), not loss.backward() followed by manually flipping gradient signs. The latter can break with optimizers that maintain running statistics (Adam’s momentum and variance).
Audit log shows accuracy drop greater than 5%
This means the retain fine-tuning phase isn’t doing enough. Increase retain_epochs or bump the retain learning rate. Another option: instead of fine-tuning on the full retain set, fine-tune only on data that’s semantically similar to the forget set. This focuses recovery where it matters most.