Source code for kosmos.ml.cl_manager
from collections.abc import Iterator
from kosmos.ml.config.sl_train import SLTrainConfig
from kosmos.ml.dataloader import make_train_test_dataloaders
from kosmos.ml.sl_result import SLTestIterationResult, SLTrainIterationResult
from kosmos.ml.sl_trainer import SLTrainer
[docs]
class CLManager:
"""Centralized learning manager for supervised learning classification tasks.
Attributes:
config (SLTrainConfig): Supervised learning training configuration.
dataset (SLDataset): The dataset used for training and testing.
model (Model): The model instance to be trained.
trainer (SLTrainer): The trainer instance managing the training process.
train_loader (DataLoader): DataLoader for the training data.
test_loader (DataLoader): DataLoader for the test data.
"""
def __init__(
self,
config: SLTrainConfig,
) -> None:
"""Initialize the centralized learning manager.
Args:
config: Supervised learning training configuration.
"""
self.config = config
self.dataset = self.config.dataset
self.model = self.config.model_config.get_instance(
self.dataset.input_dimension, self.dataset.output_dim
)
self.trainer = SLTrainer(
self.model,
self.config.optimizer_config,
self.config.lr_scheduler_config,
self.config.loss_config,
self.config.max_grad_norm,
)
train_loaders, self.test_loader = make_train_test_dataloaders(
self.dataset,
self.config.train_split,
self.config.batch_size,
)
self.train_loader = train_loaders[0]
[docs]
def train(self) -> Iterator[SLTrainIterationResult]:
"""Run training on the train data over the configured number of epochs.
Returns:
Iterator[SLTrainIterationResult]: An iterator yielding one train result per epoch.
"""
yield from self.trainer.train(self.config.num_epochs, self.train_loader)
[docs]
def test(self) -> SLTestIterationResult:
"""Evaluate the model on the test data.
Returns:
SLTestIterationResult: Result of the test iteration.
"""
return self.trainer.test(self.test_loader)