:py:mod:`kosmos.ml.sl_trainer`
==============================

.. py:module:: kosmos.ml.sl_trainer


Module Attributes
-----------------

.. py:data:: DEVICE
   :value: 'cpu'


Classes
-------

.. py:class:: SLTrainer(model: kosmos.ml.models.model.Model, optimizer_config: kosmos.ml.config.factories.optimizer.OptimizerConfig, lr_scheduler_config: kosmos.ml.config.factories.lr_scheduler.LearningRateSchedulerConfig | None, loss_config: kosmos.ml.config.factories.loss.LossConfig, max_grad_norm: float | None)

   Trainer for supervised learning classification tasks.

   Initialize a supervised learning trainer.

   :param model: The model to train and evaluate.
   :type model: Model
   :param optimizer_config: Optimizer configuration.
   :type optimizer_config: OptimizerConfig
   :param lr_scheduler_config: Learning rate scheduler configuration, or None.
   :type lr_scheduler_config: LearningRateSchedulerConfig | None
   :param loss_config: Loss configuration.
   :type loss_config: LossConfig
   :param max_grad_norm: Maximum gradient norm, or None.
   :type max_grad_norm: float | None

   |

   .. rubric:: Methods

   .. py:method:: train(num_epochs: int, dataloader: torch.utils.data.DataLoader, fl_round: int | None = None, node: kosmos.topology.node.Node | None = None) -> collections.abc.Iterator[kosmos.ml.sl_result.SLTrainIterationResult]

      Train the model on the given training data.

      :param num_epochs: Number of epochs to run.
      :type num_epochs: int
      :param dataloader: DataLoader providing the training data.
      :type dataloader: DataLoader
      :param fl_round: Federated learning round index to attach to the results. Defaults to None.
      :type fl_round: int | None
      :param node: Node corresponding to this iteration, attached to the results. Defaults to None.
      :type node: Node | None
      :returns: An iterator yielding one training result per epoch.
      :rtype: Iterator[SLTrainIterationResult]

   .. py:method:: test(dataloader: torch.utils.data.DataLoader) -> kosmos.ml.sl_result.SLTestIterationResult

      Evaluate the model on the given test data.

      :param dataloader: DataLoader providing the test data.
      :type dataloader: DataLoader
      :returns: Result of the test iteration.
      :rtype: SLTestIterationResult