Source code for kosmos.ml.config.factories.loss

from abc import ABC, abstractmethod

from torch.nn import CrossEntropyLoss, Module


class LossConfig(ABC):
    """Abstract base describing a loss function to use during training.

    Concrete subclasses implement :meth:`get_instance` to build the actual
    loss module from their configuration.
    """

    @abstractmethod
    def get_instance(self) -> Module:
        """Build the loss module this configuration describes.

        Returns:
            Module: Loss module instance.
        """
class CrossEntropyLossConfig(LossConfig):
    """Cross-entropy loss function configuration.

    All options are keyword-only and default to the values PyTorch's
    ``CrossEntropyLoss`` uses, so ``CrossEntropyLossConfig()`` yields a loss
    with PyTorch's default settings.

    Args:
        ignore_index: Target value that is ignored and does not contribute
            to the input gradient.
        reduction: Reduction applied to the output: ``"none"``, ``"mean"``
            or ``"sum"``.
        label_smoothing: Amount of label smoothing to apply; passed straight
            through to ``CrossEntropyLoss``.
    """

    def __init__(
        self,
        *,
        ignore_index: int = -100,
        reduction: str = "mean",
        label_smoothing: float = 0.0,
    ) -> None:
        self.ignore_index = ignore_index
        self.reduction = reduction
        self.label_smoothing = label_smoothing

    def get_instance(self) -> CrossEntropyLoss:
        """Get the cross-entropy loss instance.

        A new module is built on every call, configured with the options
        stored on this config.

        Returns:
            CrossEntropyLoss: Cross-entropy loss instance.
        """
        return CrossEntropyLoss(
            ignore_index=self.ignore_index,
            reduction=self.reduction,
            label_smoothing=self.label_smoothing,
        )