Federated Learning – Variational Quantum Circuits

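Examples of federated learning (FL) with variational quantum circuits (VQCs) on the BloodMNIST, Iris and OrganAMNIST datasets. Each example builds a network of two quantum client nodes connected by classical links to one quantum server node, then trains and tests a VQC across the clients with `FLSimulator`.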


BloodMNIST Dataset

fl_bloodmnist_vqc_example.py
  1import torch
  2
  3from kosmos.circuit_runner.pennylane_runner import PennyLaneRunner
  4from kosmos.ml.config.factories.encoding import AmplitudeEmbeddingConfig
  5from kosmos.ml.config.factories.loss import CrossEntropyLossConfig
  6from kosmos.ml.config.factories.model import VQCConfig
  7from kosmos.ml.config.factories.optimizer import AdamOptimizerConfig
  8from kosmos.ml.config.sl_train import FLTrainConfig
  9from kosmos.ml.datasets.bloodmnist_dataset import BloodMNISTDataset
 10from kosmos.simulator.fl_simulator import FLSimulator
 11from kosmos.topology.link import ClassicalLink, LinkId
 12from kosmos.topology.net import Network
 13from kosmos.topology.node import NodeId, NodeRole, QuantumNode
 14
 15
 16def construct_network() -> Network:
 17    """Construct network topology with two quantum clients and one quantum server."""
 18    network = Network()
 19
 20    # Create quantum nodes (clients and server)
 21    client_node_1 = QuantumNode(
 22        id=NodeId("client_1"),
 23        roles=[NodeRole.END_USER],
 24        num_qubits=127,
 25        coherence_time=4.0e-04,
 26    )
 27    client_node_2 = QuantumNode(
 28        id=NodeId("client_2"),
 29        roles=[NodeRole.END_USER],
 30        num_qubits=127,
 31        coherence_time=4.0e-04,
 32    )
 33    server_node = QuantumNode(
 34        id=NodeId("server"),
 35        roles=[NodeRole.END_USER],
 36        num_qubits=127,
 37        coherence_time=4.0e-04,
 38    )
 39
 40    # Create classical links connecting clients to the server
 41    client_server_link_1 = ClassicalLink(
 42        id=LinkId("c1_server"),
 43        src=client_node_1,
 44        dst=server_node,
 45        distance=1000.0,
 46        attenuation=0.0002,
 47        signal_speed=0.0002,
 48        bandwidth=10e9,
 49    )
 50    client_server_link_2 = ClassicalLink(
 51        id=LinkId("c2_server"),
 52        src=client_node_2,
 53        dst=server_node,
 54        distance=1000.0,
 55        attenuation=0.0002,
 56        signal_speed=0.0002,
 57        bandwidth=10e9,
 58    )
 59
 60    # Add nodes and links to the network
 61    network.add_node(client_node_1)
 62    network.add_node(client_node_2)
 63    network.add_node(server_node)
 64    network.add_link(client_server_link_1)
 65    network.add_link(client_server_link_2)
 66
 67    return network
 68
 69
 70def fl_bloodmnist_vqc_example() -> None:
 71    """Run example of federated training and testing on the BloodMNIST dataset using a VQC."""
 72    network = construct_network()
 73
 74    # Dataset to train and test on
 75    dataset = BloodMNISTDataset()
 76
 77    # Model configuration that defines the variational quantum circuit to use
 78    vqc_config = VQCConfig(
 79        circuit_runner=PennyLaneRunner(),
 80        num_layers=2,
 81        encoding_config=AmplitudeEmbeddingConfig(),
 82        weight_mapping_func=lambda x: torch.pi * torch.tanh(x),
 83        input_mapping_func=lambda x: torch.pi * x,
 84        weight_init_range=(-1, 1),
 85        bias_init_range=(-0.001, 0.001),
 86        data_reuploading=True,
 87        output_scaling=True,
 88    )
 89
 90    # Configure federated learning
 91    train_config = FLTrainConfig(
 92        dataset=dataset,
 93        train_split=0.7,
 94        batch_size=128,
 95        num_epochs=5,
 96        model_config=vqc_config,
 97        optimizer_config=AdamOptimizerConfig(lr=1e-3),
 98        lr_scheduler_config=None,
 99        max_grad_norm=1.0,
100        loss_config=CrossEntropyLossConfig(),
101        num_rounds=5,
102    )
103
104    # Initialize simulator, which is responsible for running the federated learning experiment
105    simulator = FLSimulator(
106        network,
107        train_config,
108        client_nodes=["client_1", "client_2"],
109        server_node="server",
110        seed=1,
111    )
112
113    # Run training
114    for epoch_result in simulator.train():
115        print(epoch_result)  # noqa: T201
116
117    # Evaluate trained model
118    print(simulator.test())  # noqa: T201
119
120
121if __name__ == "__main__":
122    fl_bloodmnist_vqc_example()

Iris Dataset

fl_iris_vqc_example.py
  1import torch
  2
  3from kosmos.circuit_runner.pennylane_runner import PennyLaneRunner
  4from kosmos.ml.config.factories.encoding import AngleEmbeddingConfig
  5from kosmos.ml.config.factories.loss import CrossEntropyLossConfig
  6from kosmos.ml.config.factories.model import VQCConfig
  7from kosmos.ml.config.factories.optimizer import AdamOptimizerConfig
  8from kosmos.ml.config.sl_train import FLTrainConfig
  9from kosmos.ml.datasets.iris_dataset import IrisDataset
 10from kosmos.simulator.fl_simulator import FLSimulator
 11from kosmos.topology.link import ClassicalLink, LinkId
 12from kosmos.topology.net import Network
 13from kosmos.topology.node import NodeId, NodeRole, QuantumNode
 14
 15
 16def construct_network() -> Network:
 17    """Construct network topology with two quantum clients and one quantum server."""
 18    network = Network()
 19
 20    # Create quantum nodes (clients and server)
 21    client_node_1 = QuantumNode(
 22        id=NodeId("client_1"),
 23        roles=[NodeRole.END_USER],
 24        num_qubits=127,
 25        coherence_time=4.0e-04,
 26    )
 27    client_node_2 = QuantumNode(
 28        id=NodeId("client_2"),
 29        roles=[NodeRole.END_USER],
 30        num_qubits=127,
 31        coherence_time=4.0e-04,
 32    )
 33    server_node = QuantumNode(
 34        id=NodeId("server"),
 35        roles=[NodeRole.END_USER],
 36        num_qubits=127,
 37        coherence_time=4.0e-04,
 38    )
 39
 40    # Create classical links connecting clients to the server
 41    client_server_link_1 = ClassicalLink(
 42        id=LinkId("c1_server"),
 43        src=client_node_1,
 44        dst=server_node,
 45        distance=1000.0,
 46        attenuation=0.0002,
 47        signal_speed=0.0002,
 48        bandwidth=10e9,
 49    )
 50    client_server_link_2 = ClassicalLink(
 51        id=LinkId("c2_server"),
 52        src=client_node_2,
 53        dst=server_node,
 54        distance=1000.0,
 55        attenuation=0.0002,
 56        signal_speed=0.0002,
 57        bandwidth=10e9,
 58    )
 59
 60    # Add nodes and links to the network
 61    network.add_node(client_node_1)
 62    network.add_node(client_node_2)
 63    network.add_node(server_node)
 64    network.add_link(client_server_link_1)
 65    network.add_link(client_server_link_2)
 66
 67    return network
 68
 69
 70def fl_iris_vqc_example() -> None:
 71    """Run example of federated training and testing on the Iris dataset using a VQC."""
 72    network = construct_network()
 73
 74    # Dataset to train and test on
 75    dataset = IrisDataset()
 76
 77    # Model configuration that defines the variational quantum circuit to use
 78    vqc_config = VQCConfig(
 79        circuit_runner=PennyLaneRunner(),
 80        num_layers=2,
 81        encoding_config=AngleEmbeddingConfig(rotation="X"),
 82        weight_mapping_func=lambda x: torch.pi * torch.tanh(x),
 83        input_mapping_func=lambda x: torch.pi * x,
 84        weight_init_range=(-1, 1),
 85        bias_init_range=(-0.001, 0.001),
 86        data_reuploading=True,
 87        output_scaling=True,
 88    )
 89
 90    # Configure federated learning
 91    train_config = FLTrainConfig(
 92        dataset=dataset,
 93        train_split=0.7,
 94        batch_size=8,
 95        num_epochs=5,
 96        model_config=vqc_config,
 97        optimizer_config=AdamOptimizerConfig(lr=0.01),
 98        lr_scheduler_config=None,
 99        max_grad_norm=1.0,
100        loss_config=CrossEntropyLossConfig(),
101        num_rounds=5,
102    )
103
104    # Initialize simulator, which is responsible for running the federated learning experiment
105    simulator = FLSimulator(
106        network,
107        train_config,
108        client_nodes=["client_1", "client_2"],
109        server_node="server",
110        seed=1,
111    )
112
113    # Run training
114    for epoch_result in simulator.train():
115        print(epoch_result)  # noqa: T201
116
117    # Evaluate trained model
118    print(simulator.test())  # noqa: T201
119
120
121if __name__ == "__main__":
122    fl_iris_vqc_example()

OrganAMNIST Dataset

fl_organamnist_vqc_example.py
  1import torch
  2
  3from kosmos.circuit_runner.pennylane_runner import PennyLaneRunner
  4from kosmos.ml.config.factories.encoding import AmplitudeEmbeddingConfig
  5from kosmos.ml.config.factories.loss import CrossEntropyLossConfig
  6from kosmos.ml.config.factories.model import VQCConfig
  7from kosmos.ml.config.factories.optimizer import AdamOptimizerConfig
  8from kosmos.ml.config.sl_train import FLTrainConfig
  9from kosmos.ml.datasets.organamnist_dataset import OrganAMNISTDataset
 10from kosmos.simulator.fl_simulator import FLSimulator
 11from kosmos.topology.link import ClassicalLink, LinkId
 12from kosmos.topology.net import Network
 13from kosmos.topology.node import NodeId, NodeRole, QuantumNode
 14
 15
 16def construct_network() -> Network:
 17    """Construct network topology with two quantum clients and one quantum server."""
 18    network = Network()
 19
 20    # Create quantum nodes (clients and server)
 21    client_node_1 = QuantumNode(
 22        id=NodeId("client_1"),
 23        roles=[NodeRole.END_USER],
 24        num_qubits=127,
 25        coherence_time=4.0e-04,
 26    )
 27    client_node_2 = QuantumNode(
 28        id=NodeId("client_2"),
 29        roles=[NodeRole.END_USER],
 30        num_qubits=127,
 31        coherence_time=4.0e-04,
 32    )
 33    server_node = QuantumNode(
 34        id=NodeId("server"),
 35        roles=[NodeRole.END_USER],
 36        num_qubits=127,
 37        coherence_time=4.0e-04,
 38    )
 39
 40    # Create classical links connecting clients to the server
 41    client_server_link_1 = ClassicalLink(
 42        id=LinkId("c1_server"),
 43        src=client_node_1,
 44        dst=server_node,
 45        distance=1000.0,
 46        attenuation=0.0002,
 47        signal_speed=0.0002,
 48        bandwidth=10e9,
 49    )
 50    client_server_link_2 = ClassicalLink(
 51        id=LinkId("c2_server"),
 52        src=client_node_2,
 53        dst=server_node,
 54        distance=1000.0,
 55        attenuation=0.0002,
 56        signal_speed=0.0002,
 57        bandwidth=10e9,
 58    )
 59
 60    # Add nodes and links to the network
 61    network.add_node(client_node_1)
 62    network.add_node(client_node_2)
 63    network.add_node(server_node)
 64    network.add_link(client_server_link_1)
 65    network.add_link(client_server_link_2)
 66
 67    return network
 68
 69
 70def fl_organamnist_vqc_example() -> None:
 71    """Run example of federated training and testing on the OrganAMNIST dataset using a VQC."""
 72    network = construct_network()
 73
 74    # Dataset to train and test on
 75    dataset = OrganAMNISTDataset()
 76
 77    # Model configuration that defines the variational quantum circuit to use
 78    vqc_config = VQCConfig(
 79        circuit_runner=PennyLaneRunner(),
 80        num_layers=2,
 81        encoding_config=AmplitudeEmbeddingConfig(),
 82        weight_mapping_func=lambda x: torch.pi * torch.tanh(x),
 83        input_mapping_func=lambda x: torch.pi * x,
 84        weight_init_range=(-1, 1),
 85        bias_init_range=(-0.001, 0.001),
 86        data_reuploading=True,
 87        output_scaling=True,
 88    )
 89
 90    # Configure federated learning
 91    train_config = FLTrainConfig(
 92        dataset=dataset,
 93        train_split=0.7,
 94        batch_size=128,
 95        num_epochs=5,
 96        model_config=vqc_config,
 97        optimizer_config=AdamOptimizerConfig(lr=1e-3),
 98        lr_scheduler_config=None,
 99        max_grad_norm=1.0,
100        loss_config=CrossEntropyLossConfig(),
101        num_rounds=5,
102    )
103
104    # Initialize simulator, which is responsible for running the federated learning experiment
105    simulator = FLSimulator(
106        network,
107        train_config,
108        client_nodes=["client_1", "client_2"],
109        server_node="server",
110        seed=1,
111    )
112
113    # Run training
114    for epoch_result in simulator.train():
115        print(epoch_result)  # noqa: T201
116
117    # Evaluate trained model
118    print(simulator.test())  # noqa: T201
119
120
121if __name__ == "__main__":
122    fl_organamnist_vqc_example()