Source code for kosmos.ml.config.factories.lr_scheduler
from abc import ABC, abstractmethod

from torch.optim import Optimizer
from torch.optim.lr_scheduler import CosineAnnealingLR, ExponentialLR, LRScheduler, StepLR
class LearningRateSchedulerConfig(ABC):
"""Learning rate scheduler configuration."""
    @abstractmethod
    def get_instance(self, optimizer: Optimizer) -> LRScheduler:
        """Get the learning rate scheduler instance.

        Args:
            optimizer (Optimizer): Optimizer instance.

        Returns:
            LRScheduler: Learning rate scheduler instance.
        """
class StepLearningRateSchedulerConfig(LearningRateSchedulerConfig):
"""Step learning rate scheduler configuration.
Attributes:
step_size (int): Period of learning rate decay.
gamma (float): Multiplicative factor of learning rate decay.
"""
def __init__(self, step_size: int, gamma: float = 0.1) -> None:
"""Initialize the step learning rate scheduler configuration.
Args:
step_size (int): Period of learning rate decay.
gamma (float): Multiplicative factor of learning rate decay. Defaults to 0.1.
"""
self.step_size = step_size
self.gamma = gamma
    def get_instance(self, optimizer: Optimizer) -> StepLR:
        """Get the step learning rate scheduler instance.

        Args:
            optimizer (Optimizer): Optimizer instance.

        Returns:
            StepLR: Step learning rate scheduler instance.
        """
        return StepLR(optimizer, step_size=self.step_size, gamma=self.gamma)
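For example, step_size=10 with gamma=0.5 halves the learning rate every 10
epochs. A usage sketch (the SGD optimizer and the placeholder model are
illustrative assumptions, as above):

optimizer = torch.optim.SGD(model.parameters(), lr=0.1)  # illustrative optimizer
config = StepLearningRateSchedulerConfig(step_size=10, gamma=0.5)
scheduler = config.get_instance(optimizer)
# lr is 0.1 for epochs 0-9, 0.05 for epochs 10-19, 0.025 for epochs 20-29, ...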
class ExponentialLearningRateSchedulerConfig(LearningRateSchedulerConfig):
"""Exponential learning rate scheduler configuration.
Attributes:
gamma (float): Multiplicative factor of learning rate decay.
"""
def __init__(self, gamma: float) -> None:
"""Initialize the exponential learning rate scheduler configuration.
Args:
gamma (float): Multiplicative factor of learning rate decay.
"""
self.gamma = gamma
    def get_instance(self, optimizer: Optimizer) -> ExponentialLR:
        """Get the exponential learning rate scheduler instance.

        Args:
            optimizer (Optimizer): Optimizer instance.

        Returns:
            ExponentialLR: Exponential learning rate scheduler instance.
        """
        return ExponentialLR(optimizer, gamma=self.gamma)
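ExponentialLR applies gamma at every step, so after t epochs the learning rate
is lr0 * gamma**t. A sketch, reusing the illustrative optimizer from above:

config = ExponentialLearningRateSchedulerConfig(gamma=0.95)
scheduler = config.get_instance(optimizer)
# with lr0 = 0.1: 0.1, 0.095, 0.09025, ... (multiplied by 0.95 each epoch)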
class CosineLearningRateSchedulerConfig(LearningRateSchedulerConfig):
"""Cosine annealing learning rate scheduler configuration.
Attributes:
max_epochs (int): Maximum number of epochs (iterations for the scheduler).
min_lr (float): Minimum learning rate.
"""
def __init__(self, max_epochs: int, min_lr: float = 0.0) -> None:
"""Initialize the cosine learning rate scheduler configuration.
Args:
max_epochs (int): Maximum number of epochs (iterations for the scheduler).
min_lr (float): Minimum learning rate. Defaults to 0.0.
"""
self.max_epochs = max_epochs
self.min_lr = min_lr
    def get_instance(self, optimizer: Optimizer) -> CosineAnnealingLR:
        """Get the cosine learning rate scheduler instance.

        Args:
            optimizer (Optimizer): Optimizer instance.

        Returns:
            CosineAnnealingLR: Cosine learning rate scheduler instance.
        """
        return CosineAnnealingLR(optimizer, T_max=self.max_epochs, eta_min=self.min_lr)
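get_instance maps max_epochs to CosineAnnealingLR's T_max and min_lr to eta_min,
so the learning rate follows eta_min + (lr0 - eta_min) * (1 + cos(pi * t / T_max)) / 2.
A sketch, again with the illustrative optimizer from above:

config = CosineLearningRateSchedulerConfig(max_epochs=100, min_lr=1e-5)
scheduler = config.get_instance(optimizer)
# lr starts at lr0 = 0.1, traces half a cosine, and reaches 1e-5 at epoch 100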