Source code for kosmos.ml.config.factories.optimizer
from abc import ABC, abstractmethod
from collections.abc import Iterable
from typing import Any
import torch
from torch.optim import SGD, Adam, Optimizer
type ParamsT = (
Iterable[torch.Tensor] | Iterable[dict[str, Any]] | Iterable[tuple[str, torch.Tensor]]
)
[docs]
class OptimizerConfig(ABC):
    """Abstract base for optimizer configurations.

    Concrete subclasses hold optimizer hyperparameters and implement
    :meth:`get_instance` to build the optimizer for a set of parameters.
    """

    @abstractmethod
    def get_instance(self, params: ParamsT) -> Optimizer:
        """Build the optimizer for the given parameters.

        Args:
            params (ParamsT): Parameters to optimize.

        Returns:
            Optimizer: Optimizer instance.
        """
[docs]
class SGDOptimizerConfig(OptimizerConfig):
    """Configuration that builds ``torch.optim.SGD`` optimizers.

    Attributes:
        lr (float): Learning rate.
        momentum (float): Momentum factor.
        weight_decay (float): Weight decay.
        nesterov (bool): Whether to use Nesterov momentum.
    """

    def __init__(
        self,
        lr: float = 1e-3,
        momentum: float = 0.0,
        weight_decay: float = 0.0,
        *,
        nesterov: bool = False,
    ) -> None:
        """Store the SGD hyperparameters; no optimizer is created here.

        Args:
            lr (float): Learning rate. Defaults to 1e-3.
            momentum (float): Momentum factor. Defaults to 0.0.
            weight_decay (float): Weight decay. Defaults to 0.0.
            nesterov (bool): Whether to use Nesterov momentum. Only applicable
                when momentum is non-zero. Defaults to False.
        """
        # The values are kept as given and are not validated here.
        # NOTE(review): torch.optim.SGD presumably rejects ``nesterov=True``
        # combined with zero momentum when the optimizer is built -- confirm.
        self.lr, self.momentum = lr, momentum
        self.weight_decay, self.nesterov = weight_decay, nesterov

    def get_instance(self, params: ParamsT) -> SGD:
        """Create an SGD optimizer from the stored hyperparameters.

        Args:
            params (ParamsT): Parameters to optimize.

        Returns:
            SGD: SGD optimizer instance.
        """
        hyperparameters = {
            "lr": self.lr,
            "momentum": self.momentum,
            "weight_decay": self.weight_decay,
            "nesterov": self.nesterov,
        }
        return SGD(params, **hyperparameters)
[docs]
class AdamOptimizerConfig(OptimizerConfig):
    """Configuration that builds ``torch.optim.Adam`` optimizers.

    Attributes:
        lr (float): Learning rate.
        weight_decay (float): Weight decay.
    """

    def __init__(self, lr: float = 1e-3, weight_decay: float = 0.0) -> None:
        """Store the Adam hyperparameters; no optimizer is created here.

        Args:
            lr (float): Learning rate. Defaults to 1e-3.
            weight_decay (float): Weight decay. Defaults to 0.0.
        """
        self.lr, self.weight_decay = lr, weight_decay

    def get_instance(self, params: ParamsT) -> Adam:
        """Create an Adam optimizer from the stored hyperparameters.

        Args:
            params (ParamsT): Parameters to optimize.

        Returns:
            Adam: Adam optimizer instance.
        """
        hyperparameters = {"lr": self.lr, "weight_decay": self.weight_decay}
        return Adam(params, **hyperparameters)