Set Up a Flower Server and Client in Under 50 Lines#
Federated learning keeps data on the devices that generated it. Instead of shipping data to a central server, you ship the model to the data, train locally, and send only weight updates back. Flower (flwr) is the best framework for this right now – it handles the orchestration so you can focus on your model.
Install the dependencies:
1
| pip install flwr[simulation] torch torchvision
|
Here’s a complete working example. First, define your model and a Flower client that trains it locally:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
| import flwr as fl
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
from collections import OrderedDict
# Simple CNN -- swap this for whatever architecture you need
class Net(nn.Module):
    """Two conv-pool stages plus a two-layer classifier head (10 classes).

    Expects single-channel 28x28 inputs (e.g. MNIST): each pooling stage
    halves spatial size, 28 -> 14 -> 7, hence the 32 * 7 * 7 flatten width.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, 3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(32 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)
        self.relu = nn.ReLU()

    def forward(self, x):
        # conv -> relu -> pool, twice: 28x28 -> 14x14 -> 7x7.
        out = self.pool(self.relu(self.conv1(x)))
        out = self.pool(self.relu(self.conv2(out)))
        # Collapse all non-batch dimensions before the classifier head.
        out = out.flatten(1)
        out = self.relu(self.fc1(out))
        return self.fc2(out)
def get_parameters(model):
    """Return every tensor in the model's state_dict as a CPU NumPy array."""
    state = model.state_dict()
    return [tensor.cpu().numpy() for tensor in state.values()]
def set_parameters(model, parameters):
    """Overwrite the model's weights in place from a list of NumPy arrays.

    The arrays must be ordered exactly like model.state_dict().values();
    strict loading raises if any key is missing or unexpected.
    """
    names = model.state_dict().keys()
    state_dict = OrderedDict(
        (name, torch.tensor(array)) for name, array in zip(names, parameters)
    )
    model.load_state_dict(state_dict, strict=True)
class FlowerClient(fl.client.NumPyClient):
    """Flower client that trains and evaluates a local PyTorch model."""

    def __init__(self, model, trainloader, testloader):
        self.model = model
        self.trainloader = trainloader
        self.testloader = testloader

    def get_parameters(self, config):
        """Return the current local weights as a list of NumPy ndarrays."""
        return get_parameters(self.model)

    def fit(self, parameters, config):
        """Load the global weights, train one local epoch, return new weights."""
        set_parameters(self.model, parameters)
        train(self.model, self.trainloader, epochs=1)
        num_examples = len(self.trainloader.dataset)
        return get_parameters(self.model), num_examples, {}

    def evaluate(self, parameters, config):
        """Load the global weights and report local test loss and accuracy."""
        set_parameters(self.model, parameters)
        loss, accuracy = test(self.model, self.testloader)
        num_examples = len(self.testloader.dataset)
        return float(loss), num_examples, {"accuracy": float(accuracy)}
def train(model, loader, epochs=1):
    """Train `model` on `loader` for `epochs` passes with SGD + momentum.

    Uses cross-entropy loss; updates the model in place and returns None.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    model.train()
    for _ in range(epochs):
        for inputs, targets in loader:
            optimizer.zero_grad()
            loss = criterion(model(inputs), targets)
            loss.backward()
            optimizer.step()
def test(model, loader):
criterion = nn.CrossEntropyLoss()
model.eval()
loss, correct, total = 0.0, 0, 0
with torch.no_grad():
for images, labels in loader:
outputs = model(images)
loss += criterion(outputs, labels).item()
correct += (outputs.argmax(1) == labels).sum().item()
total += labels.size(0)
return loss / len(loader), correct / total
|
Run the Simulation#
Flower’s simulation engine lets you test federated training without spinning up separate processes. This is the fastest way to iterate:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
| from flwr.simulation import start_simulation
# MNIST preprocessing: to-tensor conversion plus the standard MNIST
# per-channel mean/std normalization.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
# Downloads MNIST into ./data on first run; reuses the cache afterwards.
full_train = datasets.MNIST("./data", train=True, download=True, transform=transform)
full_test = datasets.MNIST("./data", train=False, transform=transform)
NUM_CLIENTS = 10
# Split data across clients (IID split for now)
# Integer division: any remainder samples at the tail are simply unused.
partition_size = len(full_train) // NUM_CLIENTS
indices = list(range(len(full_train)))
def client_fn(cid: str) -> fl.client.Client:
    """Build the simulated client for partition `cid`.

    Each client receives a disjoint contiguous slice of both the training
    and the test set, plus its own fresh model instance.
    """
    client_id = int(cid)

    # Contiguous slice of the training indices for this client.
    train_slice = indices[client_id * partition_size:(client_id + 1) * partition_size]
    trainset = Subset(full_train, train_slice)

    # Each client gets a slice of the test set too
    test_chunk = len(full_test) // NUM_CLIENTS
    test_slice = list(range(client_id * test_chunk, (client_id + 1) * test_chunk))
    testset = Subset(full_test, test_slice)

    model = Net()
    trainloader = DataLoader(trainset, batch_size=32, shuffle=True)
    testloader = DataLoader(testset, batch_size=32)
    return FlowerClient(model, trainloader, testloader).to_client()
# FedAvg is the default and best starting point
strategy = fl.server.strategy.FedAvg(
    fraction_fit=0.5,        # train on 50% of clients per round
    fraction_evaluate=0.3,   # evaluate on 30%
    min_fit_clients=3,       # minimum 3 clients per round
    min_available_clients=NUM_CLIENTS,  # wait until every client has joined
)
# Run the in-process simulation: NUM_CLIENTS virtual clients, 5 rounds.
start_simulation(
    client_fn=client_fn,
    num_clients=NUM_CLIENTS,
    config=fl.server.ServerConfig(num_rounds=5),
    strategy=strategy,
)
|
After 5 rounds with 10 clients, you should see the global model reaching around 97% accuracy on MNIST. Each round, only 5 clients train (50% sampling), and the server averages their weight updates using FedAvg.
Handle Non-IID Data#
Real-world federated data is almost never IID. One hospital might see mostly cardiac cases while another specializes in orthopedics. Simulating this is critical for realistic experiments.
The simplest non-IID split assigns each client only a subset of the label classes:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
| import numpy as np
def create_non_iid_split(dataset, num_clients, classes_per_client=2):
    """Give each client data from only `classes_per_client` classes."""
    # Bucket sample indices by label. Iterating the dataset applies any
    # transforms, so this pass can be slow on large transformed datasets.
    label_indices = {}
    for idx, (_, label) in enumerate(dataset):
        label_indices.setdefault(label, []).append(idx)

    all_labels = sorted(label_indices)
    client_indices = [[] for _ in range(num_clients)]

    for client in range(num_clients):
        # Rotate through the label list so assignments wrap around.
        first = client * classes_per_client
        assigned = [all_labels[pos % len(all_labels)]
                    for pos in range(first, first + classes_per_client)]
        for label in assigned:
            # Each client takes an equal, disjoint chunk of that label's
            # samples, keyed by the client's own position.
            idxs = label_indices[label]
            chunk = len(idxs) // num_clients
            client_indices[client].extend(idxs[client * chunk:(client + 1) * chunk])
    return client_indices
|
When you switch to non-IID splits, expect the global model to converge more slowly and reach a lower final accuracy. Two strategies help: increasing num_rounds (more communication rounds) and increasing fraction_fit (more clients per round, giving better coverage of the label space).
Add Differential Privacy to Federated Rounds#
Federated learning alone is not enough for strong privacy. Weight updates can still leak information about individual training samples. You need differential privacy on top.
The approach is straightforward: clip the norm of each client’s model update, average the clipped updates at the server, and add Gaussian noise to the aggregate. Flower lets you build a custom strategy for this:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
| from flwr.server.strategy import FedAvg
from flwr.common import FitRes, Parameters, Scalar
import numpy as np
from typing import List, Tuple, Optional, Dict, Union
from flwr.server.client_proxy import ClientProxy
from flwr.common import ndarrays_to_parameters, parameters_to_ndarrays
class DPFedAvg(FedAvg):
    """FedAvg with per-client L2 clipping and server-side Gaussian noise.

    NOTE(review): what gets clipped/noised here are the full model weights
    each client returns (``fit_res.parameters``), not the deltas from the
    previous global model -- the surrounding text describes clipping
    "updates", so confirm this matches the intended DP accounting.
    """

    def __init__(self, noise_multiplier: float = 1.0,
                 clipping_norm: float = 1.0, **kwargs):
        # noise_multiplier: privacy/utility knob -- scales the noise std.
        # clipping_norm: max L2 norm allowed for one client's contribution.
        super().__init__(**kwargs)
        self.noise_multiplier = noise_multiplier
        self.clipping_norm = clipping_norm

    def aggregate_fit(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, FitRes]],
        failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
        """Clip each client's tensors, average them, then add Gaussian noise."""
        # Nothing to aggregate if no client succeeded this round.
        if not results:
            return None, {}
        # Extract and clip updates
        clipped_updates = []
        for _, fit_res in results:
            weights = parameters_to_ndarrays(fit_res.parameters)
            # Compute L2 norm of the full update
            flat = np.concatenate([w.flatten() for w in weights])
            norm = np.linalg.norm(flat)
            # Clip if needed
            # 1e-8 guards against division by zero for an all-zero result.
            scale = min(1.0, self.clipping_norm / (norm + 1e-8))
            clipped = [w * scale for w in weights]
            clipped_updates.append(clipped)
        # Average the clipped updates
        # NOTE(review): this is a plain (unweighted) mean, unlike FedAvg's
        # num_examples-weighted average -- typical for DP, but it differs
        # from the parent class's aggregation behavior.
        num_clients = len(clipped_updates)
        avg_update = [
            np.mean(np.array([u[i] for u in clipped_updates]), axis=0)
            for i in range(len(clipped_updates[0]))
        ]
        # Add calibrated Gaussian noise
        # After clipping, the mean's L2 sensitivity is clipping_norm / n.
        sensitivity = self.clipping_norm / num_clients
        noise_std = sensitivity * self.noise_multiplier
        noisy_update = [
            w + np.random.normal(0, noise_std, size=w.shape)
            for w in avg_update
        ]
        # Second element is the metrics dict Flower expects (empty here).
        return ndarrays_to_parameters(noisy_update), {}
|
Use it like any other strategy:
1
2
3
4
5
6
7
# Drop-in replacement for FedAvg: same sampling knobs, plus DP parameters.
dp_strategy = DPFedAvg(
    noise_multiplier=1.0,  # higher = stronger privacy, lower accuracy
    clipping_norm=1.0,     # max L2 norm of each client's contribution
    fraction_fit=0.5,
    min_fit_clients=3,
    min_available_clients=NUM_CLIENTS,
)
|
The noise_multiplier controls the privacy-utility tradeoff. Higher values give stronger privacy (lower epsilon) but hurt accuracy. Start with 1.0 and tune from there. A clipping_norm of 1.0 works well for most CNNs – increase it if you see the model failing to learn at all.
Evaluate the Global Model#
After each federated round, Flower calls evaluate on the sampled clients. But you also want to evaluate the aggregated global model on a held-out centralized test set. Add an evaluate_fn to your strategy:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
def get_evaluate_fn(testloader):
    """Build a server-side callback that evaluates the aggregated model."""

    def evaluate(server_round, parameters, config):
        model = Net()
        # Strategies may hand us either a Parameters object or raw ndarrays.
        if isinstance(parameters, Parameters):
            ndarrays = parameters_to_ndarrays(parameters)
        else:
            ndarrays = parameters
        set_parameters(model, ndarrays)
        loss, accuracy = test(model, testloader)
        print(f"Round {server_round}: loss={loss:.4f}, accuracy={accuracy:.4f}")
        return loss, {"accuracy": accuracy}

    return evaluate
# Held-out centralized loader for server-side evaluation of the global model.
central_testloader = DataLoader(full_test, batch_size=64)
strategy = fl.server.strategy.FedAvg(
    fraction_fit=0.5,
    min_fit_clients=3,
    min_available_clients=NUM_CLIENTS,
    # Called on the server after every aggregation round.
    evaluate_fn=get_evaluate_fn(central_testloader),
)
|
This runs after every aggregation round and gives you a clean accuracy curve for the global model, independent of which clients happened to participate.
Common Errors and Fixes#
RuntimeError: Sizes of tensors must match – This happens when your model architecture differs between clients or between the client and server. Every participant must use the exact same model class with the same hyperparameters. Double-check that all clients instantiate the same Net().
grpc._channel._InactiveRpcError – The server isn’t running or the client can’t reach it. In simulation mode you won’t see this, but in a real deployment, make sure the server starts before clients connect. Set server_address="0.0.0.0:8080" and verify firewall rules.
Non-IID training produces random-chance accuracy – If each client only sees 1-2 classes, local models become extremely biased. The fix: increase fraction_fit to 0.8+ so more clients contribute each round, or use FedProx instead of FedAvg (set proximal_mu=0.1 in FedProx strategy) to keep local models closer to the global model.
DP noise destroys model accuracy – You added too much noise relative to the number of clients. Reduce noise_multiplier, increase clipping_norm, or add more clients. The noise scales with clipping_norm / num_clients, so more participants means less noise per parameter.
ValueError: NumPy array is not writable – Some versions of Flower return read-only arrays from parameters_to_ndarrays. Fix it by calling np.copy() on each array before modifying it in place.
Memory issues with many simulated clients – Flower’s simulation creates all client models in the same process. Set ray_init_args={"num_cpus": 4} in start_simulation to limit parallelism, or reduce the model size for initial experiments.