# Source code for kosmos.ml.sl_trainer
from collections.abc import Iterator
import numpy as np
import torch
from torch.utils.data import DataLoader
from kosmos.ml.config.factories.loss import LossConfig
from kosmos.ml.config.factories.lr_scheduler import LearningRateSchedulerConfig
from kosmos.ml.config.factories.optimizer import OptimizerConfig
from kosmos.ml.models.model import Model
from kosmos.ml.sl_metrics import calculate_sl_metrics
from kosmos.ml.sl_result import SLTestIterationResult, SLTrainIterationResult
from kosmos.topology.node import Node
# Device used for all training/evaluation in this module; CPU-only for now.
DEVICE = "cpu"
class SLTrainer:
    """Trainer for supervised learning classification tasks.

    Attributes:
        device (str): The PyTorch device to use for training.
        model (Model): The model to train/evaluate.
        optimizer (Optimizer): The optimizer instance for training.
        lr_scheduler (LRScheduler | None): The learning rate scheduler instance.
        criterion (Module): The loss function used for training.
        max_grad_norm (float | None): Maximum gradient norm for gradient clipping.
    """

    def __init__(
        self,
        model: Model,
        optimizer_config: OptimizerConfig,
        lr_scheduler_config: LearningRateSchedulerConfig | None,
        loss_config: LossConfig,
        max_grad_norm: float | None,
    ) -> None:
        """Initialize a supervised learning trainer.

        Args:
            model (Model): The model to train/evaluate.
            optimizer_config (OptimizerConfig): Optimizer configuration.
            lr_scheduler_config (LearningRateSchedulerConfig | None): Learning rate
                scheduler configuration, or None to train without a scheduler.
            loss_config (LossConfig): Loss configuration.
            max_grad_norm (float | None): Maximum gradient norm for gradient
                clipping, or None to disable clipping.
        """
        self.device = DEVICE
        # Move the model to the target device BEFORE building the optimizer so
        # the optimizer holds references to the on-device parameters.
        self.model = model.to(self.device)
        self.optimizer = optimizer_config.get_instance(self.model.parameters())
        self.lr_scheduler = (
            lr_scheduler_config.get_instance(self.optimizer)
            if lr_scheduler_config is not None
            else None
        )
        self.criterion = loss_config.get_instance()
        self.max_grad_norm = max_grad_norm

    def train(
        self,
        num_epochs: int,
        dataloader: DataLoader,
        fl_round: int | None = None,
        node: Node | None = None,
    ) -> Iterator[SLTrainIterationResult]:
        """Train the model on the given train data.

        Args:
            num_epochs (int): Number of epochs to run.
            dataloader (DataLoader): DataLoader providing the train data.
            fl_round (int | None): Federated learning round index to attach to the
                results. Defaults to None.
            node (Node | None): Node corresponding to this iteration to attach to
                the results. Defaults to None.

        Returns:
            Iterator[SLTrainIterationResult]: An iterator yielding one train result
                per epoch.
        """
        for epoch in range(num_epochs):
            result = self._train_epoch(epoch, dataloader)
            # Attach federated-learning context so downstream consumers can
            # associate this result with a round/node.
            result.fl_round = fl_round
            result.node = node
            yield result

    def test(self, dataloader: DataLoader) -> SLTestIterationResult:
        """Evaluate the model on the given test dataloader.

        Args:
            dataloader (DataLoader): DataLoader providing the test data.

        Returns:
            SLTestIterationResult: Result of the test iteration.
        """
        return self._evaluate(dataloader)

    def _train_epoch(self, epoch: int, dataloader: DataLoader) -> SLTrainIterationResult:
        """Run a single training epoch.

        Args:
            epoch (int): The index of the epoch.
            dataloader (DataLoader): DataLoader providing the train data.

        Returns:
            SLTrainIterationResult: Result of the training iteration.
        """
        self.model.train()
        total_loss = 0.0
        all_preds: list[int] = []
        all_targets: list[int] = []
        for x, y in dataloader:
            x = x.to(self.device)
            y = y.to(self.device)
            self.optimizer.zero_grad(set_to_none=True)
            y_pred = self.model(x)
            # CrossEntropy-style losses expect integer class targets.
            batch_loss = self.criterion(y_pred, y.long())
            batch_loss.backward()
            if self.max_grad_norm is not None:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
            self.optimizer.step()
            total_loss += batch_loss.item()
            # Collect predictions and targets for metrics; no_grad avoids
            # building an autograd graph for the argmax bookkeeping.
            with torch.no_grad():
                all_preds.extend(y_pred.argmax(-1).cpu().tolist())
                all_targets.extend(y.long().cpu().tolist())
        # NOTE(review): scheduler steps once per epoch — confirm the configured
        # scheduler is epoch-based rather than per-batch.
        if self.lr_scheduler:
            self.lr_scheduler.step()
        avg_loss, metrics = self._summarize(total_loss, dataloader, all_preds, all_targets)
        return SLTrainIterationResult(avg_loss, metrics, epoch)

    def _evaluate(self, dataloader: DataLoader) -> SLTestIterationResult:
        """Run model evaluation.

        Args:
            dataloader (DataLoader): DataLoader providing the test data.

        Returns:
            SLTestIterationResult: Result of the test iteration.
        """
        self.model.eval()
        total_loss = 0.0
        all_preds: list[int] = []
        all_targets: list[int] = []
        # inference_mode disables autograd tracking for the whole evaluation pass.
        with torch.inference_mode():
            for x, y in dataloader:
                x = x.to(self.device)
                y = y.to(self.device)
                y_pred = self.model(x)
                batch_loss = self.criterion(y_pred, y.long())
                total_loss += batch_loss.item()
                # Collect predictions and targets for metrics
                all_preds.extend(y_pred.argmax(-1).cpu().tolist())
                all_targets.extend(y.long().cpu().tolist())
        avg_loss, metrics = self._summarize(total_loss, dataloader, all_preds, all_targets)
        return SLTestIterationResult(avg_loss, metrics)

    @staticmethod
    def _summarize(
        total_loss: float,
        dataloader: DataLoader,
        all_preds: list,
        all_targets: list,
    ) -> tuple[float, dict]:
        """Compute the per-batch average loss and classification metrics.

        The loss is averaged over batches (not samples), so an uneven final
        batch is weighted the same as full batches. Raises ZeroDivisionError
        for an empty dataloader, matching the original behavior.

        Args:
            total_loss (float): Sum of per-batch loss values.
            dataloader (DataLoader): The dataloader that produced the batches.
            all_preds (list): Predicted class indices across all batches.
            all_targets (list): True class indices across all batches.

        Returns:
            tuple[float, dict]: Average loss and the computed metrics.
        """
        avg_loss = total_loss / len(dataloader)
        metrics = calculate_sl_metrics(
            y_true=np.asarray(all_targets),
            y_pred=np.asarray(all_preds),
        )
        return avg_loss, metrics