Source code for kosmos.ml.fl.fl_server
import torch
from torch.utils.data import DataLoader
from kosmos.ml.sl_result import SLTestIterationResult
from kosmos.ml.sl_trainer import SLTrainer
from kosmos.topology.node import Node
[docs]
class FLServer:
"""Federated learning server.
Attributes:
trainer (SLTrainer): The trainer used by this server.
node (Node): The node representing this server in the topology.
model (Model): The model instance associated with this server.
"""
def __init__(self, trainer: SLTrainer, node: Node) -> None:
"""Initialize a federated learning server.
Args:
trainer (SLTrainer): The trainer used by this server.
node (Node): The node representing this server in the topology.
"""
self.trainer = trainer
self.node = node
self.model = self.trainer.model
[docs]
def get_model_state(self) -> dict:
"""Get the current model state of the server.
Returns:
dict: A state_dict containing the model parameters.
"""
return self.model.state_dict()
[docs]
def aggregate(self, client_states: list[dict]) -> None:
"""Aggregate model states from clients using simple averaging.
Args:
client_states (list[dict]): A list of model state_dicts from clients.
"""
global_dict = self.model.state_dict()
for key in global_dict:
global_dict[key] = torch.stack(
[cs[key].detach().clone() for cs in client_states]
).mean(dim=0)
self.model.load_state_dict(global_dict)
[docs]
def test(self, dataloader: DataLoader) -> SLTestIterationResult:
"""Evaluate the global model on test data.
Args:
dataloader (DataLoader): Dataloader providing the test data.
Returns:
SLTestIterationResult: The result of the test iteration.
"""
return self.trainer.test(dataloader)