kosmos.ml.fl.fl_client

Classes

class FLClient(trainer: kosmos.ml.sl_trainer.SLTrainer, node: kosmos.topology.node.Node)

Federated learning client.

Initialize a federated learning client.

Parameters:
  • trainer (SLTrainer) – The trainer used for local training.

  • node (Node) – The node representing this client in the topology.


Methods

get_model_state() → dict

Get the current model state of the client.

Returns:

A state_dict containing the model parameters.

Return type:

dict
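A server that collects client states usually keeps them while training continues. This page does not say whether `get_model_state()` returns a copy. If it returns the model's live `state_dict()`, note that PyTorch's `state_dict()` tensors share storage with the parameters, so later in-place optimizer steps would also change the collected state. The minimal sketch below shows why a defensive `copy.deepcopy` snapshot is a safe default. `ToyClient` is a hypothetical stand-in for FLClient, and plain lists stand in for tensors. `copy.deepcopy` also works on dicts of tensors.

```python
import copy


class ToyClient:
    """Hypothetical stand-in for FLClient; lists play the role of tensors."""

    def __init__(self) -> None:
        self._state = {"weight": [0.0, 0.0]}

    def get_model_state(self) -> dict:
        # Returns the live dict, mimicking a state_dict that aliases parameters.
        return self._state

    def local_step(self) -> None:
        # In-place update, like an optimizer step on parameter tensors.
        for i in range(len(self._state["weight"])):
            self._state["weight"][i] += 1.0


client = ToyClient()
aliased = client.get_model_state()                  # shares storage with the model
snapshot = copy.deepcopy(client.get_model_state())  # independent copy

client.local_step()

print(aliased["weight"])   # [1.0, 1.0]: changed by the later update
print(snapshot["weight"])  # [0.0, 0.0]: unchanged
```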

set_model_state(state_dict: dict) → None

Set the client’s model state.

Parameters:

state_dict (dict) – The state_dict to load as the client’s model state.
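Together, `get_model_state()` and `set_model_state()` are enough for a server-side aggregation step. The sketch below assumes federated averaging (FedAvg), which weights each client by its number of local samples. This page does not say that kosmos uses FedAvg. `ToyClient`, `fedavg`, and the sample counts are hypothetical, and plain floats stand in for tensors. The same arithmetic also works on floating-point tensors in real state dicts.

```python
class ToyClient:
    """Hypothetical stand-in exposing the FLClient state methods."""

    def __init__(self, state: dict) -> None:
        self._state = dict(state)

    def get_model_state(self) -> dict:
        return dict(self._state)

    def set_model_state(self, state_dict: dict) -> None:
        self._state = dict(state_dict)


def fedavg(states: list[dict], weights: list[int]) -> dict:
    """Weighted average of state dicts that share the same keys."""
    total = sum(weights)
    return {
        key: sum(w * s[key] for s, w in zip(states, weights)) / total
        for key in states[0]
    }


clients = [ToyClient({"w": 1.0, "b": 0.0}), ToyClient({"w": 3.0, "b": 2.0})]
num_samples = [100, 300]  # hypothetical local dataset sizes

global_state = fedavg([c.get_model_state() for c in clients], num_samples)
for c in clients:
    c.set_model_state(global_state)  # broadcast the aggregate to every client

print(global_state)  # {'w': 2.5, 'b': 1.5}
```

Weighting by sample count means a client with three times as much data pulls the average three times as hard. Uniform weights (all 1) give a plain mean.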

train(num_epochs: int, dataloader: torch.utils.data.DataLoader, fl_round: int) → collections.abc.Iterator[kosmos.ml.sl_result.SLTrainIterationResult]

Train the client’s model on the given data.

Parameters:
  • num_epochs (int) – Number of epochs to run.

  • dataloader (DataLoader) – Dataloader providing the local training data.

  • fl_round (int) – The federated learning round index for this training run.

Returns:

An iterator yielding one training result per epoch.

Return type:

Iterator[SLTrainIterationResult]
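Because `train()` returns an iterator with one result per epoch, the caller drives local training by consuming it. This page does not say whether `train()` is a generator. If it is, no epoch runs until iteration begins, so calling `client.train(...)` without looping over the result would train nothing. The minimal sketch below illustrates this. `ToyClient` is hypothetical and its `train` is a generator. `EpochResult` is a hypothetical stand-in for `SLTrainIterationResult`, whose fields are not documented on this page.

```python
from collections.abc import Iterator
from dataclasses import dataclass


@dataclass
class EpochResult:
    """Hypothetical stand-in for SLTrainIterationResult."""

    fl_round: int
    epoch: int


class ToyClient:
    def __init__(self) -> None:
        self.epochs_run = 0

    def train(self, num_epochs: int, dataloader, fl_round: int) -> Iterator[EpochResult]:
        for epoch in range(num_epochs):
            for _batch in dataloader:
                pass  # a real client would run forward/backward passes here
            self.epochs_run += 1
            yield EpochResult(fl_round=fl_round, epoch=epoch)


client = ToyClient()
data = [[0.1, 0.2], [0.3, 0.4]]  # any iterable of batches; a DataLoader in practice

pending = client.train(num_epochs=3, dataloader=data, fl_round=0)
print(client.epochs_run)  # 0: the generator has not started yet

results = list(pending)   # consuming the iterator runs all epochs
print(client.epochs_run)  # 3
print([r.epoch for r in results])  # [0, 1, 2]
```

Yielding per epoch lets the caller log or stop early between epochs. A `for` loop over `client.train(...)` works just as well as `list(...)` when results are handled one at a time.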