Federated Learning – Variational Quantum Circuits

Examples of federated learning with variational quantum circuits.


Iris Dataset

fl_iris_vqc_example.py
  1import torch
  2
  3from kosmos.ml.config.factories.encoding import AngleEmbeddingConfig
  4from kosmos.ml.config.factories.loss import CrossEntropyLossConfig
  5from kosmos.ml.config.factories.model import VQCConfig
  6from kosmos.ml.config.factories.optimizer import AdamOptimizerConfig
  7from kosmos.ml.config.sl_train import FLTrainConfig
  8from kosmos.ml.datasets.iris_dataset import IrisDataset
  9from kosmos.simulator.fl_simulator import FLSimulator
 10from kosmos.topology.link import ClassicalLink, LinkId
 11from kosmos.topology.net import Network
 12from kosmos.topology.node import NodeId, NodeRole, QuantumNode
 13
 14
 15def construct_network() -> Network:
 16    """Construct network topology with two quantum clients and one quantum server."""
 17    network = Network()
 18
 19    # Create quantum nodes (clients and server)
 20    client_node_1 = QuantumNode(
 21        id=NodeId("client_1"),
 22        roles=[NodeRole.END_USER],
 23        num_qubits=127,
 24        coherence_time=4.0e-04,
 25    )
 26    client_node_2 = QuantumNode(
 27        id=NodeId("client_2"),
 28        roles=[NodeRole.END_USER],
 29        num_qubits=127,
 30        coherence_time=4.0e-04,
 31    )
 32    server_node = QuantumNode(
 33        id=NodeId("server"),
 34        roles=[NodeRole.END_USER],
 35        num_qubits=127,
 36        coherence_time=4.0e-04,
 37    )
 38
 39    # Create classical links connecting clients to the server
 40    client_server_link_1 = ClassicalLink(
 41        id=LinkId("c1_server"),
 42        src=client_node_1,
 43        dst=server_node,
 44        distance=1000.0,
 45        attenuation=0.0002,
 46        signal_speed=0.0002,
 47        bandwidth=10e9,
 48    )
 49    client_server_link_2 = ClassicalLink(
 50        id=LinkId("c2_server"),
 51        src=client_node_2,
 52        dst=server_node,
 53        distance=1000.0,
 54        attenuation=0.0002,
 55        signal_speed=0.0002,
 56        bandwidth=10e9,
 57    )
 58
 59    # Add nodes and links to the network
 60    network.add_node(client_node_1)
 61    network.add_node(client_node_2)
 62    network.add_node(server_node)
 63    network.add_link(client_server_link_1)
 64    network.add_link(client_server_link_2)
 65
 66    return network
 67
 68
 69def fl_iris_vqc_example() -> None:
 70    """Run example of federated training and testing on the Iris dataset using a VQC."""
 71    network = construct_network()
 72
 73    # Dataset to train and test on
 74    dataset = IrisDataset()
 75
 76    # Model configuration that defines the variational quantum circuit to use
 77    vqc_config = VQCConfig(
 78        num_layers=2,
 79        encoding_config=AngleEmbeddingConfig(rotation="X"),
 80        weight_mapping_func=lambda x: torch.pi * torch.tanh(x),
 81        input_mapping_func=lambda x: torch.pi * x,
 82        weight_init_range=(-1, 1),
 83        bias_init_range=(-0.001, 0.001),
 84        data_reuploading=True,
 85        output_scaling=True,
 86    )
 87
 88    # Configure federated learning
 89    train_config = FLTrainConfig(
 90        dataset=dataset,
 91        train_split=0.7,
 92        batch_size=8,
 93        num_epochs=5,
 94        model_config=vqc_config,
 95        optimizer_config=AdamOptimizerConfig(lr=0.01),
 96        lr_scheduler_config=None,
 97        max_grad_norm=1.0,
 98        loss_config=CrossEntropyLossConfig(),
 99        num_rounds=5,
100    )
101
102    # Initialize simulator, which is responsible for running the federated learning experiment
103    simulator = FLSimulator(
104        network,
105        train_config,
106        client_nodes=["client_1", "client_2"],
107        server_node="server",
108        seed=1,
109    )
110
111    # Run training
112    for epoch_result in simulator.train():
113        print(epoch_result)  # noqa: T201
114
115    # Evaluate trained model
116    print(simulator.test())  # noqa: T201
117
118
# Run the example only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    fl_iris_vqc_example()