Source code for kosmos.ml.fl.fl_client

from collections.abc import Iterator

from torch.utils.data import DataLoader

from kosmos.ml.sl_result import SLTrainIterationResult
from kosmos.ml.sl_trainer import SLTrainer
from kosmos.topology.node import Node


class FLClient:
    """Federated learning client.

    Ties a local ``SLTrainer`` to a node in the topology. It exposes the
    client's model state for reading/writing and runs local training rounds.

    Attributes:
        trainer (SLTrainer): The trainer used for local training.
        node (Node): The node representing this client in the topology.
        model: The trainer's model (``trainer.model``), captured once at
            construction time. If ``trainer.model`` is later rebound to a
            different object, this attribute keeps referring to the old one.
    """

    def __init__(self, trainer: SLTrainer, node: Node) -> None:
        """Initialize a federated learning client.

        Args:
            trainer (SLTrainer): The trainer used for local training.
            node (Node): The node representing this client in the topology.
        """
        self.trainer = trainer
        self.node = node
        # Keep a direct handle to the trainer's model, so that get/set of the
        # state act on the same object the trainer trains.
        self.model = self.trainer.model

    def get_model_state(self) -> dict:
        """Get the current model state of the client.

        Returns:
            dict: The result of ``self.model.state_dict()``. Assuming the model
            is a ``torch.nn.Module``, the returned mapping references the live
            parameter/buffer tensors rather than copies. Later local training
            is therefore visible through it; clone the tensors if an
            independent snapshot is required.
        """
        return self.model.state_dict()

    def set_model_state(self, state_dict: dict) -> None:
        """Load ``state_dict`` into the client's model.

        Args:
            state_dict (dict): The state_dict to set the client's model state to.
                It is passed unchanged to ``self.model.load_state_dict``.
        """
        self.model.load_state_dict(state_dict)

    def train(
        self, num_epochs: int, dataloader: DataLoader, fl_round: int
    ) -> Iterator[SLTrainIterationResult]:
        """Train the client's model on the given data.

        This method is a generator. ``self.trainer.train`` is not invoked, and
        no training happens, until the returned iterator is consumed.

        Args:
            num_epochs (int): Number of epochs to run.
            dataloader (DataLoader): Dataloader providing the local training data.
            fl_round (int): The federated learning round index for this training run.

        Yields:
            SLTrainIterationResult: Each result produced by
            ``self.trainer.train`` (one per epoch), forwarded unchanged. This
            client's ``node`` is passed along to the trainer.
        """
        yield from self.trainer.train(num_epochs, dataloader, fl_round, self.node)