Set Up a Flower Server and Client in Under 50 Lines

Federated learning keeps data on the devices that generated it. Instead of shipping data to a central server, you ship the model to the data, train locally, and send only weight updates back. Flower (flwr) is a popular open-source framework for this: it handles the orchestration so you can focus on your model.

Install the dependencies:

pip install "flwr[simulation]" torch torchvision

Here’s a complete working example. First, define your model and a Flower client that trains it locally:

import flwr as fl
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
from collections import OrderedDict

# Simple CNN -- swap this for whatever architecture you need
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, 3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(32 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        x = x.view(-1, 32 * 7 * 7)
        x = self.relu(self.fc1(x))
        return self.fc2(x)


def get_parameters(model):
    return [val.cpu().numpy() for val in model.state_dict().values()]


def set_parameters(model, parameters):
    params = zip(model.state_dict().keys(), parameters)
    state_dict = OrderedDict({k: torch.tensor(v) for k, v in params})
    model.load_state_dict(state_dict, strict=True)


class FlowerClient(fl.client.NumPyClient):
    def __init__(self, model, trainloader, testloader):
        self.model = model
        self.trainloader = trainloader
        self.testloader = testloader

    def get_parameters(self, config):
        return get_parameters(self.model)

    def fit(self, parameters, config):
        set_parameters(self.model, parameters)
        train(self.model, self.trainloader, epochs=1)
        return get_parameters(self.model), len(self.trainloader.dataset), {}

    def evaluate(self, parameters, config):
        set_parameters(self.model, parameters)
        loss, accuracy = test(self.model, self.testloader)
        return float(loss), len(self.testloader.dataset), {"accuracy": float(accuracy)}


def train(model, loader, epochs=1):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    model.train()
    for _ in range(epochs):
        for images, labels in loader:
            optimizer.zero_grad()
            criterion(model(images), labels).backward()
            optimizer.step()


def test(model, loader):
    criterion = nn.CrossEntropyLoss()
    model.eval()
    loss, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for images, labels in loader:
            outputs = model(images)
            loss += criterion(outputs, labels).item()
            correct += (outputs.argmax(1) == labels).sum().item()
            total += labels.size(0)
    return loss / len(loader), correct / total

Run the Simulation

Flower’s simulation engine lets you run federated training on a single machine without launching separate server and client processes yourself. This is the fastest way to iterate:

from flwr.simulation import start_simulation

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

full_train = datasets.MNIST("./data", train=True, download=True, transform=transform)
full_test = datasets.MNIST("./data", train=False, transform=transform)

NUM_CLIENTS = 10

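# Split data across clients (IID split for now): shuffle once with a fixed
# seed so each contiguous slice is a random sample of the training set
partition_size = len(full_train) // NUM_CLIENTS
indices = torch.randperm(len(full_train), generator=torch.Generator().manual_seed(42)).tolist()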


def client_fn(cid: str) -> fl.client.Client:
    idx = int(cid)
    start = idx * partition_size
    end = start + partition_size
    trainset = Subset(full_train, indices[start:end])
    # Each client gets a slice of the test set too
    test_start = idx * (len(full_test) // NUM_CLIENTS)
    test_end = test_start + (len(full_test) // NUM_CLIENTS)
    testset = Subset(full_test, list(range(test_start, test_end)))

    model = Net()
    return FlowerClient(
        model,
        DataLoader(trainset, batch_size=32, shuffle=True),
        DataLoader(testset, batch_size=32),
    ).to_client()


# FedAvg is the standard baseline and a sensible starting point
strategy = fl.server.strategy.FedAvg(
    fraction_fit=0.5,           # train on 50% of clients per round
    fraction_evaluate=0.3,      # evaluate on 30%
    min_fit_clients=3,          # minimum 3 clients per round
    min_available_clients=NUM_CLIENTS,
)

start_simulation(
    client_fn=client_fn,
    num_clients=NUM_CLIENTS,
    config=fl.server.ServerConfig(num_rounds=5),
    strategy=strategy,
)

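After 5 rounds with 10 clients, you should see the global model reaching around 97% accuracy on MNIST. Each round, only 5 clients train (50% sampling), and FedAvg sets the new global model to the average of their returned weights, weighted by each client's number of training examples. By default this setup logs only the aggregated loss. To track accuracy, add a centralized evaluate_fn as described in "Evaluate the Global Model" below.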
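The averaging step is simple enough to sketch directly. The function below is a minimal NumPy illustration of FedAvg's example-weighted mean. It is not Flower's internal implementation, and `fedavg` and the toy client values are hypothetical.

```python
import numpy as np


def fedavg(results):
    """Example-weighted average of client weights.

    `results` is a list of (layers, num_examples) pairs, where `layers` is a
    list of NumPy arrays in the same order on every client.
    """
    total_examples = sum(num_examples for _, num_examples in results)
    num_layers = len(results[0][0])
    return [
        sum(layers[i] * (num_examples / total_examples) for layers, num_examples in results)
        for i in range(num_layers)
    ]


# Two hypothetical clients: B holds 3x as much data, so it gets 3x the weight
client_a = ([np.array([1.0, 1.0]), np.array([0.0])], 100)
client_b = ([np.array([3.0, 5.0]), np.array([2.0])], 300)
global_layers = fedavg([client_a, client_b])
print(global_layers)  # -> layer values [2.5, 4.0] and [1.5]
```

This weighting is why FlowerClient.fit returns len(self.trainloader.dataset) as its second value.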

Handle Non-IID Data

Real-world federated data is almost never IID. One hospital might see mostly cardiac cases while another specializes in orthopedics. Simulating this is critical for realistic experiments.

The simplest non-IID split assigns each client only a subset of the label classes:

import numpy as np

def create_non_iid_split(dataset, num_clients, classes_per_client=2):
    """Give each client data from only `classes_per_client` classes."""
    # Read labels from `targets` when available (torchvision datasets expose it);
    # otherwise fall back to iterating the dataset, which is much slower
    if hasattr(dataset, "targets"):
        labels = np.asarray(dataset.targets)
    else:
        labels = np.asarray([label for _, label in dataset])
    all_labels = sorted(np.unique(labels).tolist())

    # Assign `classes_per_client` classes per client, rotating through labels,
    # and record which clients own each class
    owners = {label: [] for label in all_labels}
    for i in range(num_clients):
        for j in range(i * classes_per_client, (i + 1) * classes_per_client):
            owners[all_labels[j % len(all_labels)]].append(i)

    # Split each class evenly among the clients that own it, so no samples are dropped
    client_indices = [[] for _ in range(num_clients)]
    for label, clients in owners.items():
        if not clients:
            continue
        label_idxs = np.flatnonzero(labels == label)
        for client, chunk in zip(clients, np.array_split(label_idxs, len(clients))):
            client_indices[client].extend(chunk.tolist())

    return client_indices
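Before training on a skewed split, it helps to confirm what each client actually holds. This hypothetical `label_counts` helper tabulates per-client class counts with `np.bincount`. The toy labels below stand in for `dataset.targets`.

```python
import numpy as np


def label_counts(labels, client_indices, num_classes):
    """Return a (num_clients, num_classes) array of sample counts per class."""
    labels = np.asarray(labels)
    return np.stack([
        np.bincount(labels[np.asarray(idxs, dtype=int)], minlength=num_classes)
        for idxs in client_indices
    ])


# Toy example: 3 classes, 2 clients with a skewed split
labels = [0, 0, 1, 1, 2, 2]
client_indices = [[0, 1, 2], [3, 4, 5]]
counts = label_counts(labels, client_indices, num_classes=3)
print(counts)
# [[2 1 0]
#  [0 1 2]]
```

On MNIST, something like `label_counts(full_train.targets, create_non_iid_split(full_train, NUM_CLIENTS), num_classes=10)` should show two nonzero columns per row.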

When you switch to non-IID splits, expect the global model to converge more slowly and to a lower final accuracy. Two changes help: increase num_rounds (more communication rounds) and increase fraction_fit (more clients per round, for better coverage of the label space).
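The fraction_fit advice can be made concrete. Assume clients are sampled uniformly without replacement, and assume that every class is held by exactly 2 clients, as in the split above with 10 clients and 2 classes each. The sketch below computes the expected number of classes that no sampled client holds in a given round. The function name is hypothetical.

```python
from math import comb


def expected_missing_classes(num_clients, holders_per_class, num_classes, sampled):
    """Expected number of classes with no sampled client in a round."""
    # A class is missing only if none of its holders is among the sampled clients
    p_missing = comb(num_clients - holders_per_class, sampled) / comb(num_clients, sampled)
    return num_classes * p_missing


for fraction_fit in (0.5, 0.8):
    sampled = int(10 * fraction_fit)
    print(fraction_fit, round(expected_missing_classes(10, 2, 10, sampled), 2))
# 0.5 -> about 2.22 classes missing per round; 0.8 -> about 0.22
```

Going from 5 to 8 sampled clients cuts the expected number of unseen classes per round by a factor of ten under these assumptions.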

Add Differential Privacy to Federated Rounds

Federated learning alone is not enough for strong privacy. Weight updates can still leak information about individual training samples. You need differential privacy on top.

The approach is straightforward: clip the norm of each client's model update (its new weights minus the current global weights), average the clipped updates, and add Gaussian noise to that average at the server. Flower lets you build a custom strategy for this:

from flwr.server.strategy import FedAvg
from flwr.common import FitRes, Parameters, Scalar
import numpy as np
from typing import List, Tuple, Optional, Dict, Union
from flwr.server.client_proxy import ClientProxy
from flwr.common import ndarrays_to_parameters, parameters_to_ndarrays


class DPFedAvg(FedAvg):
    def __init__(self, noise_multiplier: float = 1.0,
                 clipping_norm: float = 1.0, **kwargs):
        super().__init__(**kwargs)
        self.noise_multiplier = noise_multiplier
        self.clipping_norm = clipping_norm
        self.current_weights = None

    def configure_fit(self, server_round, parameters, client_manager):
        # Remember the global model sent out this round; updates are measured against it
        self.current_weights = parameters_to_ndarrays(parameters)
        return super().configure_fit(server_round, parameters, client_manager)

    def aggregate_fit(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, FitRes]],
        failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
        if not results:
            return None, {}

        # Compute each client's update (new weights minus global weights) and clip it
        clipped_updates = []
        for _, fit_res in results:
            weights = parameters_to_ndarrays(fit_res.parameters)
            update = [w - g for w, g in zip(weights, self.current_weights)]
            # L2 norm of the full update, across all layers
            norm = np.linalg.norm(np.concatenate([u.ravel() for u in update]))
            # Scale down only if the norm exceeds clipping_norm
            scale = min(1.0, self.clipping_norm / (norm + 1e-12))
            clipped_updates.append([u * scale for u in update])

        # Average the clipped updates layer by layer
        num_clients = len(clipped_updates)
        avg_update = [np.mean(layer, axis=0) for layer in zip(*clipped_updates)]

        # Add Gaussian noise calibrated to the sensitivity of the average (C / m),
        # then apply the noisy update to the global model
        noise_std = self.noise_multiplier * self.clipping_norm / num_clients
        new_weights = [
            (g + u + np.random.normal(0.0, noise_std, size=u.shape)).astype(g.dtype)
            for g, u in zip(self.current_weights, avg_update)
        ]

        return ndarrays_to_parameters(new_weights), {}

Use it like any other strategy:

dp_strategy = DPFedAvg(
    noise_multiplier=1.0,
    clipping_norm=1.0,
    fraction_fit=0.5,
    min_fit_clients=3,
    min_available_clients=NUM_CLIENTS,
)

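The noise_multiplier controls the privacy-utility tradeoff. Higher values give stronger privacy (a lower epsilon for the same number of rounds) but hurt accuracy. Computing the actual epsilon requires a privacy accountant, which this strategy does not include. Start with 1.0 and tune from there. A clipping_norm of 1.0 is a reasonable starting point for a small CNN like this one. If the model fails to learn at all, the clip may be cutting off most of each update, so try increasing it. Keep in mind that the noise grows in proportion to clipping_norm.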
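To get a feel for the numbers, here is a small NumPy sketch of the two operations the strategy applies. The first clips a list of arrays to a joint L2 norm of clipping_norm. The second computes the noise standard deviation noise_multiplier * clipping_norm / num_clients. The helper names are hypothetical.

```python
import numpy as np


def clip_update(update, clipping_norm):
    """Scale a list of arrays so their joint L2 norm is at most clipping_norm."""
    norm = np.linalg.norm(np.concatenate([u.ravel() for u in update]))
    scale = min(1.0, clipping_norm / (norm + 1e-12))
    return [u * scale for u in update], norm


def noise_std(noise_multiplier, clipping_norm, num_clients):
    """Std of the Gaussian noise added to the averaged update."""
    return noise_multiplier * clipping_norm / num_clients


update = [np.array([3.0, 0.0]), np.array([[4.0]])]  # joint L2 norm = 5
clipped, original_norm = clip_update(update, clipping_norm=1.0)
print(original_norm, clipped)  # 5.0, then arrays scaled by ~1/5: [0.6, 0.0] and [[0.8]]

for m in (5, 50):
    print(m, noise_std(1.0, 1.0, m))  # 5 clients -> 0.2, 50 clients -> 0.02
```

With 5 clients sampled per round (fraction_fit=0.5 of 10), the noise on the averaged update has a standard deviation of 0.2. Sampling more clients per round therefore reduces the accuracy cost at a fixed noise_multiplier.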

Evaluate the Global Model

After each federated round, Flower calls evaluate on the sampled clients. But you also want to evaluate the aggregated global model on a held-out centralized test set. Add an evaluate_fn to your strategy:

def get_evaluate_fn(testloader):
    def evaluate(server_round, parameters, config):
        # FedAvg passes the global weights here as a list of NumPy arrays (NDArrays)
        model = Net()
        set_parameters(model, parameters)
        loss, accuracy = test(model, testloader)
        print(f"Round {server_round}: loss={loss:.4f}, accuracy={accuracy:.4f}")
        return loss, {"accuracy": accuracy}
    return evaluate

central_testloader = DataLoader(full_test, batch_size=64)

strategy = fl.server.strategy.FedAvg(
    fraction_fit=0.5,
    min_fit_clients=3,
    min_available_clients=NUM_CLIENTS,
    evaluate_fn=get_evaluate_fn(central_testloader),
)

This runs once on the initial parameters (round 0) and then after every aggregation round. It gives you a clean accuracy curve for the global model, independent of which clients happened to participate.
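Flower's FedAvg also aggregates the client-side evaluate results, but by default it only combines the loss. The "accuracy" entry in each client's metrics dict is dropped unless you pass an evaluate_metrics_aggregation_fn. A minimal example-weighted version could look like the sketch below. The `Metrics` alias is a local stand-in for `flwr.common.Metrics`, and `weighted_average` is a hypothetical name.

```python
from typing import Dict, List, Tuple

Metrics = Dict[str, float]  # local stand-in for flwr.common.Metrics


def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
    """Combine per-client accuracies, weighting each by its number of test examples."""
    total = sum(num_examples for num_examples, _ in metrics)
    if total == 0:
        return {}
    accuracy = sum(num_examples * m["accuracy"] for num_examples, m in metrics) / total
    return {"accuracy": accuracy}


print(weighted_average([(100, {"accuracy": 0.9}), (300, {"accuracy": 0.5})]))  # accuracy of about 0.6
```

Pass it as `FedAvg(..., evaluate_metrics_aggregation_fn=weighted_average)` to log the distributed accuracy each round next to the centralized result.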

Common Errors and Fixes

RuntimeError: Sizes of tensors must match – This happens when your model architecture differs between clients or between the client and server. Every participant must use the exact same model class with the same hyperparameters. Double-check that all clients instantiate the same Net().

grpc._channel._InactiveRpcError – The server isn’t running or the client can’t reach it. In simulation mode you won’t see this, but in a real deployment, make sure the server starts before clients connect. Start the server with server_address="0.0.0.0:8080" so it listens on all network interfaces, point each client at the server's reachable IP address and port, and verify firewall rules.

Non-IID training produces random-chance accuracy – If each client only sees 1-2 classes, local models become extremely biased. One fix is to increase fraction_fit to 0.8+ so more clients contribute each round. Another is to switch from FedAvg to FedProx (fl.server.strategy.FedProx with proximal_mu=0.1), which keeps local models closer to the global model. FedProx only sends proximal_mu to clients in the fit config, so your client's training loop must add the proximal term itself.
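Here is a minimal PyTorch sketch of that client-side proximal term. It assumes the function is called right after the received global weights are loaded into the model, so it can snapshot them before training. `train_fedprox` is a hypothetical drop-in for the `train` function above, and the toy linear model and random data stand in for Net and the MNIST loader.

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset


def train_fedprox(model, loader, proximal_mu, epochs=1, lr=0.01):
    """Local SGD on cross-entropy plus the FedProx term (mu / 2) * ||w - w_global||^2."""
    # Snapshot the global weights received this round; they stay fixed during training
    global_params = [p.detach().clone() for p in model.parameters()]
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    model.train()
    for _ in range(epochs):
        for images, labels in loader:
            optimizer.zero_grad()
            loss = criterion(model(images), labels)
            # Penalize drift away from the global model
            prox = sum(((w - g) ** 2).sum() for w, g in zip(model.parameters(), global_params))
            (loss + (proximal_mu / 2) * prox).backward()
            optimizer.step()


# Toy usage on random data
torch.manual_seed(0)
toy_model = nn.Linear(4, 3)
start = [p.detach().clone() for p in toy_model.parameters()]
toy_loader = DataLoader(TensorDataset(torch.randn(16, 4), torch.randint(0, 3, (16,))), batch_size=4)
train_fedprox(toy_model, toy_loader, proximal_mu=0.1)
```

In FlowerClient.fit, the call would follow set_parameters(self.model, parameters) and read the value from config["proximal_mu"].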

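DP noise destroys model accuracy – The noise is too large relative to the signal in the averaged update. Its standard deviation is noise_multiplier * clipping_norm / num_clients, so more participants per round means less noise on the average. You can reduce noise_multiplier, which weakens the privacy guarantee, or sample more clients per round. Another option is to lower clipping_norm toward the typical update norm, which shrinks the noise without clipping away much of the signal.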

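ValueError: assignment destination is read-only – Some versions of Flower return read-only arrays from parameters_to_ndarrays, so modifying them in place fails. In-place arithmetic reports "output array is read-only" instead, and PyTorch warns that "the given NumPy array is not writable" when wrapping one in a tensor. Fix it by calling np.copy() on each array before modifying it in place.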
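The read-only case is easy to reproduce with plain NumPy. `np.frombuffer` over an immutable bytes object yields a read-only array, which stands in here for a read-only array returned by parameters_to_ndarrays in the affected versions.

```python
import numpy as np

# frombuffer over immutable bytes produces a read-only view
raw = np.arange(4, dtype=np.float32).tobytes()
readonly = np.frombuffer(raw, dtype=np.float32)
print(readonly.flags.writeable)  # False

try:
    readonly *= 2.0  # in-place update on a read-only array
except ValueError as err:
    print("ValueError:", err)

# The fix: take a writable copy first, then modify it freely
writable = np.copy(readonly)
writable *= 2.0
print(writable)  # [0. 2. 4. 6.]
```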

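Memory issues with many simulated clients – Flower’s simulation runs clients in Ray worker processes on one machine. By default it runs as many clients in parallel as the available resources allow, and each one builds its own model and data loaders. Limit parallelism by requesting more resources per client (client_resources={"num_cpus": 2}) or by capping Ray's total (ray_init_args={"num_cpus": 4}) in start_simulation. You can also reduce the model size for initial experiments.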