# Source code for kosmos.ml.config.factories.loss
from abc import ABC, abstractmethod
from torch.nn import CrossEntropyLoss, Module
class LossConfig(ABC):
    """Loss function configuration for training.

    Abstract base class: concrete configurations implement
    :meth:`get_instance` to build the loss module they describe.
    """

    @abstractmethod
    def get_instance(self) -> Module:
        """Get the loss module instance.

        Returns:
            Module: Loss module instance.

        """
class CrossEntropyLossConfig(LossConfig):
    """Cross-entropy loss function configuration."""

    def get_instance(self) -> CrossEntropyLoss:
        """Get the cross-entropy loss instance.

        A new ``CrossEntropyLoss`` is constructed on every call, with no
        arguments, so it uses the ``torch.nn.CrossEntropyLoss`` defaults.

        Returns:
            CrossEntropyLoss: Cross-entropy loss instance.

        """
        return CrossEntropyLoss()