Source code for kosmos.ml.sl_result
from abc import ABC
from dataclasses import dataclass
from kosmos.ml.sl_metrics import SLMetrics
from kosmos.topology.node import Node
def _format_node_id_value(value: str, length: int) -> str:
"""Format a node ID value to a fixed length.
Args:
value (str): The node ID value.
length (int): The fixed length.
"""
if len(value) > length:
return value[: length - 3] + "..."
return value.ljust(length)
[docs]
@dataclass
class SLIterationResult(ABC):
    """Outcome of one supervised learning iteration.

    Common base for the training and test iteration results.

    Attributes:
        loss (float): Mean loss value over the iteration.
        metrics (SLMetrics): Evaluation metrics for the iteration.
    """

    loss: float
    metrics: SLMetrics

    def __str__(self) -> str:
        """Render the loss to four decimal places, followed by the metrics' string form."""
        loss_text = "Loss: {:.4f}".format(self.loss)
        return "{} | {}".format(loss_text, self.metrics)
[docs]
@dataclass
class SLTrainIterationResult(SLIterationResult):
    """Outcome of one supervised learning training iteration.

    Attributes:
        loss (float): Mean loss value over the iteration.
        metrics (SLMetrics): Evaluation metrics for the iteration.
        epoch (int): Epoch index corresponding to this iteration.
        fl_round (int | None): Federated learning round index corresponding to this iteration.
            Defaults to None.
        node (Node | None): The node corresponding to this iteration. Defaults to None.
    """

    epoch: int
    fl_round: int | None = None
    node: Node | None = None

    def __str__(self) -> str:
        """Render a ``" | "``-separated row describing this training iteration.

        Columns, in order: the node ID (only if ``node`` is set, fitted to
        10 characters), the round label (only if ``fl_round`` is set), the
        epoch label, then the base loss/metrics text. The round and epoch
        labels are left-aligned and padded to 10 characters.
        """
        columns: list[str] = []
        if self.node is not None:
            columns.append(_format_node_id_value(self.node.id.value, 10))
        if self.fl_round is not None:
            columns.append(f"Round {self.fl_round:4d}".ljust(10))
        columns.append(f"Epoch {self.epoch:4d}".ljust(10))
        columns.append(super().__str__())
        return " | ".join(columns)
[docs]
@dataclass
class SLTestIterationResult(SLIterationResult):
    """Outcome of one supervised learning test iteration.

    Attributes:
        loss (float): Mean loss value over the iteration.
        metrics (SLMetrics): Evaluation metrics for the iteration.
    """

    def __str__(self) -> str:
        """Render a ``"Test"`` label padded to 10 characters, then the base loss/metrics text."""
        label = "Test".ljust(10)
        return " | ".join((label, super().__str__()))