Resources:
- https://github.com/s222416822/QFL_Basics/blob/master/QFLBasic.ipynb
- https://qiskit-community.github.io/qiskit-machine-learning/tutorials/02a_training_a_quantum_model_on_a_real_dataset.html
- https://qiskit-community.github.io/qiskit-machine-learning/getting_started.html
What is Quantum Federated Learning?
Quantum Federated Learning (QFL) is a decentralized machine learning framework where multiple clients collaboratively train a global quantum model without sharing their private, local data. In this framework, clients perform local training on their own datasets and then share only their updated model weights with a central server for aggregation.
Code Outline and Explanation
The provided code implements QFL using the Iris dataset split across three simulated clients. Below is an outline of the technical steps:
1. Data Preparation
- Loading and Filtering: The code loads the Iris dataset and filters it to include only two classes (binary classification).
- Normalization: It uses
MinMaxScalerto scale the four features into a range suitable for quantum encoding. - Decentralization: The dataset is shuffled and split into three separate parts, simulating three independent clients who keep their data local.
2. Quantum Model Design
- Feature Map: The code uses a
ZFeatureMapto translate classical data into quantum states. - Ansatz: A
RealAmplitudesvariational circuit serves as the trainable part of the model. - SamplerQNN: This Qiskit primitive integrates the circuit with a
parityfunction to map quantum measurements to classical labels (0 or 1).
3. Federated Training Loop
The code executes 10 federated rounds. In each round, the following occurs:
- Global Evaluation: The new global weights are used to test accuracy on a separate test set to track how the collective model is improving.
- Local Training: Each client receives the current Global Weights, applies them to their local model, and trains on their local data for 5 iterations.
- Weight Aggregation: After local training, the server collects the weights from all clients and calculates the average (
FedAvg) to update the Global Weights.
In [1]:
!pip install qiskit==1.4.1 qiskit_machine_learning qiskit_aer qiskit_ibm_runtime -q
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 62.0/62.0 kB 3.2 MB/s eta 0:00:00 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 6.7/6.7 MB 52.3 MB/s eta 0:00:00 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 231.9/231.9 kB 16.4 MB/s eta 0:00:00 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 12.4/12.4 MB 67.6 MB/s eta 0:00:00 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.5/1.5 MB 46.3 MB/s eta 0:00:00 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 378.6/378.6 kB 22.2 MB/s eta 0:00:00 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.2/2.2 MB 39.1 MB/s eta 0:00:00 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 37.3/37.3 MB 18.1 MB/s eta 0:00:00 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 54.4/54.4 kB 4.9 MB/s eta 0:00:00 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 49.6/49.6 MB 12.2 MB/s eta 0:00:00 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 75.8/75.8 kB 5.5 MB/s eta 0:00:00 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 130.2/130.2 kB 9.6 MB/s eta 0:00:00
In [2]:
# ---------------------------------------------------------------------------------------------------------------------------------------#
# Various Resources are being used and the links provided:
# Main Resource referenced: https://qiskit-community.github.io/qiskit-machine-learning/tutorials/02a_training_a_quantum_model_on_a_real_dataset.html
# For further detail, quantum machine learning https://qiskit-community.github.io/qiskit-machine-learning/getting_started.html
# ---------------------------------------------------------------------------------------------------------------------------------------#
from qiskit_aer import AerSimulator
from qiskit_ibm_runtime import Session, SamplerV2 as Sampler, QiskitRuntimeService
from qiskit.transpiler.preset_passmanagers import generate_preset_pass_manager
from sklearn.metrics import accuracy_score, log_loss
from qiskit.circuit.library import RealAmplitudes, ZZFeatureMap, ZFeatureMap
from qiskit_machine_learning.optimizers import COBYLA
from qiskit_machine_learning.algorithms.classifiers import NeuralNetworkClassifier
from sklearn.preprocessing import MinMaxScaler
from qiskit_machine_learning.neural_networks import SamplerQNN
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.pyplot as plt
#Transpile required for execucation on real hardware
#https://quantum.cloud.ibm.com/docs/en/api/qiskit/transpiler
from qiskit.transpiler.preset_passmanagers import generate_preset_pass_manager
from qiskit_ibm_runtime.fake_provider import FakeManilaV2, FakeBrisbane
from qiskit.providers.fake_provider import GenericBackendV2
from qiskit_ibm_runtime import SamplerV2 as Sampler
from qiskit import QuantumCircuit, ClassicalRegister, QuantumRegister
# 1. PREPARE DATA https://scikit-learn.org/stable/modules/generated/sklearn.datasets.load_iris.html
iris_data = load_iris()
features = iris_data.data
labels = iris_data.target
# You can select the labels, customize your experiment
mask = labels < 2
features = features[mask]
labels = labels[mask]
# https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.MinMaxScaler.html
scaler = MinMaxScaler()
features = scaler.fit_transform(features)
#https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.train_test_split.html
X_train, X_test, y_train, y_test = train_test_split(features, labels, train_size=0.9, random_state=42)
#https://qiskit-community.github.io/qiskit-machine-learning/tutorials/02a_training_a_quantum_model_on_a_real_dataset.html
# Design or Select Classifier
# Here we are using NeuralNetworkClassifier with SamplerQNN; We need FeatureMap and Ansatz
num_features = features.shape[1]
#The ZFeatureMap is one of feature maps in the Qiskit circuit library. We pass num_features as feature_dimension, meaning the feature map will have num_features or 4 qubits.
feature_map = ZFeatureMap(feature_dimension=num_features, reps=1)
ansatz = RealAmplitudes(num_qubits=num_features, reps=3)
ansatz.measure_all() #measurement operation to all qubits
backend = AerSimulator()
sampler = Sampler(mode=backend)
optimizer = COBYLA(maxiter=10)
pm = generate_preset_pass_manager(optimization_level=1, backend=backend)
input_params = feature_map.parameters
weight_params = ansatz.parameters
# Compose the feature map and ansatz
qc = QuantumCircuit(num_features)
qc.compose(feature_map, inplace=True)
qc.compose(ansatz, inplace=True)
# Define a custom interpret function that calculates the parity of the bitstring
# function that takes a bitstring (represented as an integer) and returns its parity. The parity is typically defined as 0 if the number of '1's in the bitstring is even, and 1 if it's odd.
def parity(x):
return f"{bin(x)}".count("1") % 2
# https://qiskit-community.github.io/qiskit-machine-learning/stubs/qiskit_machine_learning.neural_networks.SamplerQNN.html
sampler_qnn = SamplerQNN(
circuit=qc,
input_params=input_params,
weight_params=weight_params,
sampler=sampler,
output_shape=2,
gradient=None,
interpret=parity,
pass_manager=pm
)
# Define number of clients, how many rounds and local iteration
n_clients = 3
federated_rounds = 10
local_maxiter = 5
# When it is set to True and fit() is called again the model uses weights from previous fit to start a new fit.
check = True #start from previously optimized state
# Randomize dataset
indices = np.random.permutation(len(X_train))
X_shuffled = X_train[indices]
y_shuffled = y_train[indices]
X_splits = np.array_split(X_shuffled, n_clients)
y_splits = np.array_split(y_shuffled, n_clients)
from sklearn.model_selection import train_test_split
clients = []
for i in range(n_clients):
X_client = X_splits[i]
y_client = y_splits[i].astype(int)
# Further split data withint device to calculate local and test accuracy for devices
X_train, X_test, y_train, y_test = train_test_split(
X_client,
y_client,
train_size=0.8,
stratify=y_client
)
classifier = NeuralNetworkClassifier(
neural_network=sampler_qnn,
optimizer=COBYLA(maxiter=local_maxiter),
warm_start=check
)
clients.append({
'id': i,
'X_train': X_train,
'X_test': X_test,
'y_train': y_train,
'y_test': y_test,
'model': classifier,
'train_scores': [],
'test_scores': [],
'training_times': [],
'n_total_samples': len(X_client),
'n_train_samples': len(X_train),
'n_test_samples': len(X_test)
})
print(f"Client {i+1}: {len(X_client)} samples → "
f"{len(X_train)} train, {len(X_test)} local test")
global_weights = None
global_accuracies = []
print(f"Starting Federated Learning: {n_clients} clients, {federated_rounds} rounds\n")
for round_num in range(1, federated_rounds + 1):
print(f"--- Round {round_num}/{federated_rounds} ---")
client_weights = []
for client in clients:
print(f" Client {client['id']+1}/{n_clients}: training on {len(client['X_train'])} samples...")
if round_num > 0:
#continue training from that point by setting initial_point to a vector of pre-trained weights.
# https://qiskit-community.github.io/qiskit-machine-learning/tutorials/11_quantum_convolutional_neural_networks.html
client['model'].initial_point = global_weights
client['model'].fit(client['X_train'], client['y_train'])
train_score = client['model'].score(client['X_train'], client['y_train'])
test_score = client['model'].score(client['X_test'], client['y_test'])
client['train_scores'].append(train_score)
client['test_scores'].append(test_score)
print(f" Quantum VQC on the training dataset: {train_score:.4f}")
print(f" Quantum VQC on the test dataset: {test_score:.4f}")
client_weights.append(client['model'].weights)
global_weights = np.mean(client_weights, axis=0)
print(f" Aggregated global weights (FedAvg)")
probs = sampler_qnn.forward(X_test, global_weights)
y_pred = np.argmax(probs, axis=1)
acc = accuracy_score(y_test, y_pred)
global_accuracies.append(acc)
print(f" Global Test Accuracy: {acc:.4f}\n")
print(f"Final Global Accuracy: {global_accuracies[-1]:.4f}")
plt.figure(figsize=(5, 6))
plt.plot(range(1, federated_rounds + 1), global_accuracies, linewidth=2)
plt.title("Global Performance", fontsize=16)
plt.xlabel("Round", fontsize=14)
plt.ylabel("Accuracy", fontsize=14)
plt.show()
ERROR:stevedore.extension:Could not load 'ibm_backend': cannot import name 'calc_final_ops' from 'qiskit.transpiler.passes.utils.remove_final_measurements' (/usr/local/lib/python3.12/dist-packages/qiskit/transpiler/passes/utils/remove_final_measurements.py) ERROR:stevedore.extension:Could not load 'ibm_dynamic_and_fractional': cannot import name 'calc_final_ops' from 'qiskit.transpiler.passes.utils.remove_final_measurements' (/usr/local/lib/python3.12/dist-packages/qiskit/transpiler/passes/utils/remove_final_measurements.py) ERROR:stevedore.extension:Could not load 'ibm_dynamic_circuits': cannot import name 'calc_final_ops' from 'qiskit.transpiler.passes.utils.remove_final_measurements' (/usr/local/lib/python3.12/dist-packages/qiskit/transpiler/passes/utils/remove_final_measurements.py) ERROR:stevedore.extension:Could not load 'ibm_fractional': cannot import name 'calc_final_ops' from 'qiskit.transpiler.passes.utils.remove_final_measurements' (/usr/local/lib/python3.12/dist-packages/qiskit/transpiler/passes/utils/remove_final_measurements.py) ERROR:stevedore.extension:Could not load 'ibm_backend': cannot import name 'calc_final_ops' from 'qiskit.transpiler.passes.utils.remove_final_measurements' (/usr/local/lib/python3.12/dist-packages/qiskit/transpiler/passes/utils/remove_final_measurements.py) ERROR:stevedore.extension:Could not load 'ibm_dynamic_and_fractional': cannot import name 'calc_final_ops' from 'qiskit.transpiler.passes.utils.remove_final_measurements' (/usr/local/lib/python3.12/dist-packages/qiskit/transpiler/passes/utils/remove_final_measurements.py) ERROR:stevedore.extension:Could not load 'ibm_dynamic_circuits': cannot import name 'calc_final_ops' from 'qiskit.transpiler.passes.utils.remove_final_measurements' (/usr/local/lib/python3.12/dist-packages/qiskit/transpiler/passes/utils/remove_final_measurements.py) ERROR:stevedore.extension:Could not load 'ibm_fractional': cannot import name 'calc_final_ops' from 'qiskit.transpiler.passes.utils.remove_final_measurements' (/usr/local/lib/python3.12/dist-packages/qiskit/transpiler/passes/utils/remove_final_measurements.py)
Client 1: 30 samples → 24 train, 6 local test Client 2: 30 samples → 24 train, 6 local test Client 3: 30 samples → 24 train, 6 local test Starting Federated Learning: 3 clients, 10 rounds --- Round 1/10 --- Client 1/3: training on 24 samples... Quantum VQC on the training dataset: 0.6667 Quantum VQC on the test dataset: 0.6667 Client 2/3: training on 24 samples... Quantum VQC on the training dataset: 0.8750 Quantum VQC on the test dataset: 0.5000 Client 3/3: training on 24 samples... Quantum VQC on the training dataset: 0.5417 Quantum VQC on the test dataset: 0.5000 Aggregated global weights (FedAvg) Global Test Accuracy: 0.1667 --- Round 2/10 --- Client 1/3: training on 24 samples... Quantum VQC on the training dataset: 0.6250 Quantum VQC on the test dataset: 0.5000 Client 2/3: training on 24 samples... Quantum VQC on the training dataset: 0.9167 Quantum VQC on the test dataset: 0.8333 Client 3/3: training on 24 samples... Quantum VQC on the training dataset: 0.5833 Quantum VQC on the test dataset: 0.5000 Aggregated global weights (FedAvg) Global Test Accuracy: 0.3333 --- Round 3/10 --- Client 1/3: training on 24 samples... Quantum VQC on the training dataset: 0.6250 Quantum VQC on the test dataset: 0.6667 Client 2/3: training on 24 samples... Quantum VQC on the training dataset: 0.9583 Quantum VQC on the test dataset: 1.0000 Client 3/3: training on 24 samples... Quantum VQC on the training dataset: 0.5833 Quantum VQC on the test dataset: 0.5000 Aggregated global weights (FedAvg) Global Test Accuracy: 0.6667 --- Round 4/10 --- Client 1/3: training on 24 samples... Quantum VQC on the training dataset: 0.6250 Quantum VQC on the test dataset: 0.6667 Client 2/3: training on 24 samples... Quantum VQC on the training dataset: 0.9583 Quantum VQC on the test dataset: 0.8333 Client 3/3: training on 24 samples... Quantum VQC on the training dataset: 0.5833 Quantum VQC on the test dataset: 0.5000 Aggregated global weights (FedAvg) Global Test Accuracy: 0.6667 --- Round 5/10 --- Client 1/3: training on 24 samples... Quantum VQC on the training dataset: 0.5833 Quantum VQC on the test dataset: 0.6667 Client 2/3: training on 24 samples... Quantum VQC on the training dataset: 0.9583 Quantum VQC on the test dataset: 0.8333 Client 3/3: training on 24 samples... Quantum VQC on the training dataset: 0.5833 Quantum VQC on the test dataset: 0.5000 Aggregated global weights (FedAvg) Global Test Accuracy: 0.6667 --- Round 6/10 --- Client 1/3: training on 24 samples... Quantum VQC on the training dataset: 0.6250 Quantum VQC on the test dataset: 0.6667 Client 2/3: training on 24 samples... Quantum VQC on the training dataset: 0.9167 Quantum VQC on the test dataset: 0.8333 Client 3/3: training on 24 samples... Quantum VQC on the training dataset: 0.5833 Quantum VQC on the test dataset: 0.5000 Aggregated global weights (FedAvg) Global Test Accuracy: 0.6667 --- Round 7/10 --- Client 1/3: training on 24 samples... Quantum VQC on the training dataset: 0.6250 Quantum VQC on the test dataset: 0.6667 Client 2/3: training on 24 samples... Quantum VQC on the training dataset: 0.9583 Quantum VQC on the test dataset: 0.8333 Client 3/3: training on 24 samples... Quantum VQC on the training dataset: 0.5833 Quantum VQC on the test dataset: 0.5000 Aggregated global weights (FedAvg) Global Test Accuracy: 0.6667 --- Round 8/10 --- Client 1/3: training on 24 samples... Quantum VQC on the training dataset: 0.6667 Quantum VQC on the test dataset: 0.6667 Client 2/3: training on 24 samples... Quantum VQC on the training dataset: 0.9583 Quantum VQC on the test dataset: 0.8333 Client 3/3: training on 24 samples... Quantum VQC on the training dataset: 0.5833 Quantum VQC on the test dataset: 0.5000 Aggregated global weights (FedAvg) Global Test Accuracy: 0.5000 --- Round 9/10 --- Client 1/3: training on 24 samples... Quantum VQC on the training dataset: 0.6250 Quantum VQC on the test dataset: 0.6667 Client 2/3: training on 24 samples... Quantum VQC on the training dataset: 0.9583 Quantum VQC on the test dataset: 0.8333 Client 3/3: training on 24 samples... Quantum VQC on the training dataset: 0.5833 Quantum VQC on the test dataset: 0.5000 Aggregated global weights (FedAvg) Global Test Accuracy: 0.6667 --- Round 10/10 --- Client 1/3: training on 24 samples... Quantum VQC on the training dataset: 0.6250 Quantum VQC on the test dataset: 0.6667 Client 2/3: training on 24 samples... Quantum VQC on the training dataset: 0.9583 Quantum VQC on the test dataset: 0.8333 Client 3/3: training on 24 samples... Quantum VQC on the training dataset: 0.5833 Quantum VQC on the test dataset: 0.5000 Aggregated global weights (FedAvg) Global Test Accuracy: 0.6667 Final Global Accuracy: 0.6667
In [3]:
client = clients[0]
plt.figure(figsize=(5, 5))
plt.plot(client['train_scores'], label='Train Acc', linewidth=2)
plt.plot(client['test_scores'], label='Test Acc', linewidth=2)
plt.title(f"Client {client['id']} Performance")
plt.xlabel("Round")
plt.ylabel("Accuracy")
plt.legend()
plt.show()
client = clients[2]
plt.figure(figsize=(5, 5))
plt.plot(client['train_scores'], label='Train Acc', linewidth=2)
plt.plot(client['test_scores'], label='Test Acc', linewidth=2)
plt.title(f"Client {client['id']} Performance")
plt.xlabel("Round")
plt.ylabel("Accuracy")
plt.legend()
plt.show()
graph LR
subgraph Server [Central Server]
Global[Global Model]
Avg[FedAvg Aggregation]
end
subgraph Clients [Local Clients]
C1[Client 1]
C2[Client 2]
C3[Client 3]
end
C1 -- Weights --> Avg
C2 -- Weights --> Avg
C3 -- Weights --> Avg
Avg -- Updated Model --> Global
Global -- Broadcast --> C1 & C2 & C3