"""Encoding configuration factories (module ``kosmos.ml.config.factories.encoding``)."""
from abc import ABC, abstractmethod
from typing import Literal
from kosmos.circuit_runner.typing import QuantumCircuitFramework
from kosmos.ml.models.vqc.encoding.encoding import AmplitudeEmbedding, AngleEmbedding, VQCEncoding
from kosmos.ml.models.vqc.encoding.pennylane_encoding import (
PennyLaneAmplitudeEmbedding,
PennyLaneAngleEmbedding,
)
from kosmos.ml.models.vqc.encoding.qiskit_encoding import (
QiskitAmplitudeEmbedding,
QiskitAngleEmbedding,
)
class EncodingConfig(ABC):
    """Abstract base for encoding configurations.

    A configuration stores the encoding's settings. The framework-specific
    encoding object is only built when :meth:`get_instance` is called, using
    whatever framework was chosen via :meth:`set_framework`.

    Attributes:
        framework (QuantumCircuitFramework | None): The framework for the encoding.
            ``None`` until :meth:`set_framework` has been called.
    """

    def __init__(self) -> None:
        """Initialize the encoding configuration with no framework selected."""
        # Annotated explicitly so type checkers see the documented
        # ``QuantumCircuitFramework | None`` type instead of inferring it from
        # the bare ``None`` initializer.
        self.framework: QuantumCircuitFramework | None = None

    def set_framework(self, framework: QuantumCircuitFramework) -> None:
        """Set the framework for the encoding.

        Args:
            framework (QuantumCircuitFramework): The framework to use for the encoding.
        """
        self.framework = framework

    def _validate_framework(self) -> None:
        """Ensure a framework has been selected before an encoding is built.

        Raises:
            ValueError: If :meth:`set_framework` has not been called yet.
        """
        if self.framework is None:
            msg = (
                "Framework must be set for an encoding config via set_framework(...) "
                "before getting an instance."
            )
            raise ValueError(msg)

    @abstractmethod
    def get_instance(self, input_dim: int, output_dim: int) -> VQCEncoding:
        """Get the encoding instance.

        Args:
            input_dim (int): Model input dimension.
            output_dim (int): Model output dimension.

        Returns:
            VQCEncoding: Encoding instance.
        """
class AngleEmbeddingConfig(EncodingConfig):
    """Configuration that builds an angle embedding for the selected framework.

    Attributes:
        rotation (Literal["X", "Y", "Z"]): Rotation axis handed to the angle embedding.
    """

    def __init__(self, rotation: Literal["X", "Y", "Z"] = "X") -> None:
        """Create an angle embedding configuration.

        Args:
            rotation (Literal["X", "Y", "Z"]): Rotation axis handed to the angle
                embedding. Defaults to "X".
        """
        super().__init__()
        self.rotation = rotation

    def get_instance(self, input_dim: int, output_dim: int) -> AngleEmbedding:
        """Build the angle embedding for the configured framework.

        Args:
            input_dim (int): Model input dimension.
            output_dim (int): Model output dimension.

        Returns:
            AngleEmbedding: Framework-specific angle embedding instance.

        Raises:
            ValueError: If no framework has been set, or if the framework is
                neither "pennylane" nor "qiskit".
        """
        self._validate_framework()
        # Framework name -> concrete angle-embedding class.
        backends: dict[str, type[AngleEmbedding]] = {
            "pennylane": PennyLaneAngleEmbedding,
            "qiskit": QiskitAngleEmbedding,
        }
        embedding_cls = backends.get(self.framework)
        if embedding_cls is None:
            msg = f"Unsupported framework: {self.framework}"
            raise ValueError(msg)
        return embedding_cls(input_dim, output_dim, self.rotation)
class AmplitudeEmbeddingConfig(EncodingConfig):
    """Configuration that builds an amplitude embedding for the selected framework.

    Attributes:
        pad_with (complex): Constant used to pad the input to size :math:`2^n`.
        normalize (bool): Whether the features are normalized.
    """

    def __init__(
        self,
        pad_with: complex = 0.3,
        *,
        normalize: bool = True,
    ) -> None:
        """Create an amplitude embedding configuration.

        Args:
            pad_with (complex): Constant used to pad the input to size :math:`2^n`.
                Defaults to 0.3.
            normalize (bool): Whether the features are normalized. Defaults to True.
        """
        super().__init__()
        self.pad_with = pad_with
        self.normalize = normalize

    def get_instance(self, input_dim: int, output_dim: int) -> AmplitudeEmbedding:
        """Build the amplitude embedding for the configured framework.

        Args:
            input_dim (int): Model input dimension.
            output_dim (int): Model output dimension.

        Returns:
            AmplitudeEmbedding: Framework-specific amplitude embedding instance.

        Raises:
            ValueError: If no framework has been set, or if the framework is
                neither "pennylane" nor "qiskit".
        """
        self._validate_framework()
        # Framework name -> concrete amplitude-embedding class.
        backends: dict[str, type[AmplitudeEmbedding]] = {
            "pennylane": PennyLaneAmplitudeEmbedding,
            "qiskit": QiskitAmplitudeEmbedding,
        }
        embedding_cls = backends.get(self.framework)
        if embedding_cls is None:
            msg = f"Unsupported framework: {self.framework}"
            raise ValueError(msg)
        return embedding_cls(input_dim, output_dim, self.pad_with, normalize=self.normalize)