Federated Learning – Variational Quantum Circuits
Examples of federated learning with variational quantum circuits.
Iris Dataset
fl_iris_vqc_example.py
1import torch
2
3from kosmos.ml.config.factories.encoding import AngleEmbeddingConfig
4from kosmos.ml.config.factories.loss import CrossEntropyLossConfig
5from kosmos.ml.config.factories.model import VQCConfig
6from kosmos.ml.config.factories.optimizer import AdamOptimizerConfig
7from kosmos.ml.config.sl_train import FLTrainConfig
8from kosmos.ml.datasets.iris_dataset import IrisDataset
9from kosmos.simulator.fl_simulator import FLSimulator
10from kosmos.topology.link import ClassicalLink, LinkId
11from kosmos.topology.net import Network
12from kosmos.topology.node import NodeId, NodeRole, QuantumNode
13
14
def construct_network() -> Network:
    """Build the example topology: two quantum client nodes linked to one server node.

    Every node is a ``QuantumNode`` created with the same ``num_qubits`` and
    ``coherence_time``; only the node id differs. Each client is connected to
    the server by its own ``ClassicalLink``; the clients are not linked to
    each other.

    Returns:
        The populated ``Network`` containing the nodes ``client_1``,
        ``client_2`` and ``server`` and the links ``c1_server`` and
        ``c2_server``.
    """
    topology = Network()

    # Client nodes first, then the server, mirroring the order in which they
    # are later registered with the network.
    clients = [
        QuantumNode(
            id=NodeId(client_name),
            roles=[NodeRole.END_USER],
            num_qubits=127,
            coherence_time=4.0e-04,
        )
        for client_name in ("client_1", "client_2")
    ]
    # NOTE(review): the server is given the END_USER role as well -- confirm
    # that this is the intended role for the aggregating node.
    server = QuantumNode(
        id=NodeId("server"),
        roles=[NodeRole.END_USER],
        num_qubits=127,
        coherence_time=4.0e-04,
    )

    # One classical link per client, always directed client -> server and
    # sharing identical physical parameters.
    uplinks = [
        ClassicalLink(
            id=LinkId(link_name),
            src=client,
            dst=server,
            distance=1000.0,
            attenuation=0.0002,
            signal_speed=0.0002,
            bandwidth=10e9,
        )
        for link_name, client in zip(("c1_server", "c2_server"), clients)
    ]

    # Register all nodes before any link that refers to them.
    for node in (*clients, server):
        topology.add_node(node)
    for uplink in uplinks:
        topology.add_link(uplink)

    return topology
67
68
def fl_iris_vqc_example() -> None:
    """Train and evaluate a variational quantum circuit on Iris with federated learning.

    Builds the network from ``construct_network()``, configures a VQC model
    and the federated training run, then drives an ``FLSimulator`` with
    ``client_1`` and ``client_2`` as clients and ``server`` as the server.
    Every result yielded by ``simulator.train()`` is printed, followed by the
    result of ``simulator.test()``.
    """
    network = construct_network()

    # Dataset used for both training and testing.
    iris = IrisDataset()

    # Model: a two-layer VQC with X-rotation angle embedding.
    encoding = AngleEmbeddingConfig(rotation="X")
    model_config = VQCConfig(
        num_layers=2,
        encoding_config=encoding,
        # pi * tanh(w) keeps the mapped value inside the open interval (-pi, pi).
        weight_mapping_func=lambda w: torch.pi * torch.tanh(w),
        # Inputs are simply scaled by pi.
        input_mapping_func=lambda v: torch.pi * v,
        weight_init_range=(-1, 1),
        bias_init_range=(-0.001, 0.001),
        data_reuploading=True,
        output_scaling=True,
    )

    # Federated training setup: Adam with a fixed learning rate (no scheduler)
    # and cross-entropy loss.
    optimizer = AdamOptimizerConfig(lr=0.01)
    loss = CrossEntropyLossConfig()
    train_config = FLTrainConfig(
        dataset=iris,
        train_split=0.7,
        batch_size=8,
        num_epochs=5,
        model_config=model_config,
        optimizer_config=optimizer,
        lr_scheduler_config=None,
        max_grad_norm=1.0,
        loss_config=loss,
        num_rounds=5,
    )

    # The simulator runs the experiment; node names must match the ids used
    # in construct_network(). A fixed seed is passed to the simulator.
    simulator = FLSimulator(
        network,
        train_config,
        client_nodes=["client_1", "client_2"],
        server_node="server",
        seed=1,
    )

    # Training: report each result as soon as the simulator yields it.
    for result in simulator.train():
        print(result)  # noqa: T201

    # Evaluation of the trained model.
    test_result = simulator.test()
    print(test_result)  # noqa: T201
117
118
# Run the example only when this file is executed as a script, not on import.
if __name__ == "__main__":
    fl_iris_vqc_example()