Federated Learning – Neural Networks

Examples of federated learning with neural networks.


Iris Dataset

fl_iris_nn_example.py
  1from kosmos.ml.config.factories.loss import CrossEntropyLossConfig
  2from kosmos.ml.config.factories.lr_scheduler import CosineLearningRateSchedulerConfig
  3from kosmos.ml.config.factories.model import NeuralNetworkConfig
  4from kosmos.ml.config.factories.optimizer import AdamOptimizerConfig
  5from kosmos.ml.config.sl_train import FLTrainConfig
  6from kosmos.ml.datasets.iris_dataset import IrisDataset
  7from kosmos.simulator.fl_simulator import FLSimulator
  8from kosmos.topology.link import ClassicalLink, LinkId
  9from kosmos.topology.net import Network
 10from kosmos.topology.node import ClassicalNode, NodeId, NodeRole
 11
 12
 13def construct_network() -> Network:
 14    """Construct network topology with two classical clients and one classical server."""
 15    network = Network()
 16
 17    # Create classical nodes (clients and server)
 18    client_node_1 = ClassicalNode(
 19        id=NodeId("client_1"),
 20        roles=[NodeRole.END_USER],
 21    )
 22    client_node_2 = ClassicalNode(
 23        id=NodeId("client_2"),
 24        roles=[NodeRole.END_USER],
 25    )
 26    server_node = ClassicalNode(
 27        id=NodeId("server"),
 28        roles=[NodeRole.END_USER],
 29    )
 30
 31    # Create classical links connecting clients to the server
 32    client_server_link_1 = ClassicalLink(
 33        id=LinkId("c1_server"),
 34        src=client_node_1,
 35        dst=server_node,
 36        distance=1000.0,
 37        attenuation=0.0002,
 38        signal_speed=0.0002,
 39        bandwidth=10e9,
 40    )
 41    client_server_link_2 = ClassicalLink(
 42        id=LinkId("c2_server"),
 43        src=client_node_2,
 44        dst=server_node,
 45        distance=1000.0,
 46        attenuation=0.0002,
 47        signal_speed=0.0002,
 48        bandwidth=10e9,
 49    )
 50
 51    # Add nodes and links to the network
 52    network.add_node(client_node_1)
 53    network.add_node(client_node_2)
 54    network.add_node(server_node)
 55    network.add_link(client_server_link_1)
 56    network.add_link(client_server_link_2)
 57
 58    return network
 59
 60
 61def fl_iris_nn_example() -> None:
 62    """Run example of federated training and testing on the Iris dataset using a neural network."""
 63    network = construct_network()
 64
 65    # Dataset to train and test on
 66    dataset = IrisDataset()
 67
 68    # Model configuration that defines the neural network to use
 69    nn_config = NeuralNetworkConfig([16, 16])
 70
 71    # Configure federated learning
 72    train_config = FLTrainConfig(
 73        dataset=dataset,
 74        train_split=0.7,
 75        batch_size=8,
 76        num_epochs=5,
 77        model_config=nn_config,
 78        optimizer_config=AdamOptimizerConfig(lr=0.01),
 79        lr_scheduler_config=CosineLearningRateSchedulerConfig(max_epochs=5),
 80        max_grad_norm=1.0,
 81        loss_config=CrossEntropyLossConfig(),
 82        num_rounds=5,
 83    )
 84
 85    # Initialize simulator, which is responsible for running the federated learning experiment
 86    simulator = FLSimulator(
 87        network,
 88        train_config,
 89        client_nodes=["client_1", "client_2"],
 90        server_node="server",
 91        seed=1,
 92    )
 93
 94    # Run training
 95    for epoch_result in simulator.train():
 96        print(epoch_result)  # noqa: T201
 97
 98    # Evaluate trained model
 99    print(simulator.test())  # noqa: T201
100
101
if __name__ == "__main__":
    # Run the example only when executed as a script, not when imported.
    fl_iris_nn_example()