"""Federated learning manager (kosmos.ml.fl.fl_manager)."""
from collections.abc import Iterator
from kosmos.ml.config.sl_train import FLTrainConfig
from kosmos.ml.dataloader import make_train_test_dataloaders
from kosmos.ml.fl.fl_client import FLClient
from kosmos.ml.fl.fl_server import FLServer
from kosmos.ml.sl_result import SLTestIterationResult, SLTrainIterationResult
from kosmos.ml.sl_trainer import SLTrainer
from kosmos.topology.node import Node
class FLManager:
    """Federated learning manager for supervised learning classification tasks.

    Attributes:
        config (FLTrainConfig): Federated learning training configuration.
        client_nodes (list[Node]): The nodes representing federated clients.
        server_node (Node): The node representing the federated server.
        dataset (SLDataset): The dataset used for training and testing.
        num_rounds (int): Number of federated learning rounds.
        num_epochs (int): Number of training epochs per round.
        train_loaders (list[DataLoader]): DataLoaders for the training data, one per client.
        test_loader (DataLoader): DataLoader for the test data.
        clients (list[FLClient] | None): The federated learning client instances.
        server (FLServer | None): The federated learning server instance.
    """

    def __init__(
        self,
        config: FLTrainConfig,
        client_nodes: list[Node],
        server_node: Node,
    ) -> None:
        """Initialize the federated learning manager.

        Args:
            config (FLTrainConfig): Federated learning training configuration.
            client_nodes (list[Node]): The nodes representing federated clients.
            server_node (Node): The node representing the federated server.
        """
        self.config = config
        self.client_nodes = client_nodes
        self.server_node = server_node
        self.dataset = config.dataset
        self.num_rounds = config.num_rounds
        self.num_epochs = config.num_epochs
        # One training subset per client node; a single shared test loader.
        self.train_loaders, self.test_loader = make_train_test_dataloaders(
            self.dataset,
            self.config.train_split,
            self.config.batch_size,
            num_train_subsets=len(self.client_nodes),
        )
        self.clients: list[FLClient] | None = None
        self.server: FLServer | None = None
        self._init_clients()
        self._init_server()

    def _get_trainer_instance(self) -> SLTrainer:
        """Create a new trainer based on the configuration.

        Each call builds a fresh model instance, so clients and server do not
        share model parameters by reference.

        Returns:
            SLTrainer: The trainer instance.
        """
        model = self.config.model_config.get_instance(
            self.dataset.input_dimension, self.dataset.output_dim
        )
        return SLTrainer(
            model,
            self.config.optimizer_config,
            self.config.lr_scheduler_config,
            self.config.loss_config,
            self.config.max_grad_norm,
        )

    def _init_clients(self) -> None:
        """Initialize one client (with its own trainer) per client node."""
        self.clients = [
            FLClient(self._get_trainer_instance(), client_node)
            for client_node in self.client_nodes
        ]

    def _init_server(self) -> None:
        """Initialize the server with its own trainer instance."""
        self.server = FLServer(self._get_trainer_instance(), self.server_node)

    def train(self) -> Iterator[SLTrainIterationResult]:
        """Run federated training across all configured rounds.

        Returns:
            Iterator[SLTrainIterationResult]: An iterator yielding one training
                result per epoch for all rounds.
        """
        for fl_round in range(self.num_rounds):
            yield from self._run_round(fl_round)

    def test(self) -> SLTestIterationResult:
        """Evaluate the global model on the test dataset.

        Returns:
            SLTestIterationResult: The result of the global model evaluation.
        """
        return self.server.test(self.test_loader)

    def _run_round(self, fl_round: int) -> Iterator[SLTrainIterationResult]:
        """Run a single round of federated learning.

        Each client starts from the current global model state, trains locally
        for ``num_epochs`` epochs on its own data subset, and the server then
        aggregates the resulting client model states into a new global model.

        Args:
            fl_round (int): The index of the federated learning round.

        Returns:
            Iterator[SLTrainIterationResult]: An iterator yielding one training
                result per epoch.
        """
        server_state = self.server.get_model_state()
        client_states = []
        # Iterate clients and their training loaders in lockstep; they are
        # constructed one-per-client-node in __init__, so the lengths match.
        for client, train_loader in zip(self.clients, self.train_loaders):
            client.set_model_state(server_state)
            yield from client.train(self.num_epochs, train_loader, fl_round)
            client_states.append(client.get_model_state())
        self.server.aggregate(client_states)