# Source code for kosmos.ml.config.sl_train

from dataclasses import dataclass

from kosmos.ml.config.factories.loss import LossConfig
from kosmos.ml.config.factories.lr_scheduler import LearningRateSchedulerConfig
from kosmos.ml.config.factories.model import ModelConfig
from kosmos.ml.config.factories.optimizer import OptimizerConfig
from kosmos.ml.datasets.dataset import SLDataset


[docs] @dataclass(frozen=True, kw_only=True) class SLTrainConfig: """Supervised learning training configuration. Attributes: dataset (SLDataset): Supervised learning dataset. train_split (float): Fraction of dataset for training. Test split is 1 - train_split. batch_size (int): Number of samples per batch. num_epochs (int): Number of training epochs. model_config (ModelConfig): Model configuration. optimizer_config (OptimizerConfig): Optimizer configuration. lr_scheduler_config (LearningRateSchedulerConfig | None): Learning rate scheduler configuration. Defaults to None. max_grad_norm (float | None): Maximum gradient norm. Defaults to 1.0. loss_config (LossConfig): Loss function configuration. """ dataset: SLDataset train_split: float batch_size: int num_epochs: int model_config: ModelConfig optimizer_config: OptimizerConfig lr_scheduler_config: LearningRateSchedulerConfig | None = None max_grad_norm: float | None = 1.0 loss_config: LossConfig def __post_init__(self) -> None: """Validate the train configuration.""" if not (0.0 < self.train_split < 1.0): msg = "train_split must be in (0,1)." raise ValueError(msg) if self.batch_size <= 0: msg = "batch_size must be > 0." raise ValueError(msg) if self.num_epochs <= 0: msg = "num_epochs must be > 0." raise ValueError(msg)
@dataclass(frozen=True, kw_only=True)
class FLTrainConfig(SLTrainConfig):
    """Federated learning training configuration.

    Attributes:
        dataset (SLDataset): Supervised learning dataset.
        train_split (float): Fraction of the dataset used for training; the
            test split is 1 - train_split.
        batch_size (int): Number of samples per batch.
        num_epochs (int): Number of training epochs.
        model_config (ModelConfig): Model configuration.
        optimizer_config (OptimizerConfig): Optimizer configuration.
        lr_scheduler_config (LearningRateSchedulerConfig | None): Learning rate
            scheduler configuration. Defaults to None.
        max_grad_norm (float | None): Maximum gradient norm. Defaults to 1.0.
        loss_config (LossConfig): Loss function configuration.
        num_rounds (int): Number of federated learning rounds.
    """

    num_rounds: int

    def __post_init__(self) -> None:
        """Validate the train configuration.

        Raises:
            ValueError: If any inherited constraint fails, or if num_rounds
                is not positive.
        """
        # Run the supervised-learning checks first, then the FL-specific one.
        super().__post_init__()
        if self.num_rounds < 1:
            msg = "num_rounds must be > 0."
            raise ValueError(msg)