Source code for kosmos.ml.datasets.digits_dataset
import numpy as np
from sklearn.datasets import load_digits
from kosmos.ml.datasets.dataset import SLDataset
[docs]
class DigitsDataset(SLDataset):
    """Handwritten-digits dataset for multiclass classification.

    Notes:
        - Number of instances: 1797
        - Number of features: 64 numeric
        - Classes: 10 (roughly balanced, digits 0-9)

    References:
        - Scikit-Learn: Digits dataset (8x8 images of handwritten digits)
          https://scikit-learn.org/stable/modules/generated/sklearn.datasets.load_digits.html
    """

    def __init__(self, *, min_max_scaler: bool = True) -> None:
        """Load the digits data and hand it to the base dataset.

        Args:
            min_max_scaler (bool): Whether to apply min-max scaling to the
                features. Forwarded unchanged to the base class.
                Defaults to True.
        """
        # return_X_y=True yields the same (data, target) arrays that the
        # Bunch object exposes as attributes.
        features, labels = load_digits(return_X_y=True)
        super().__init__(
            features.astype(np.float32),
            labels.astype(np.int64),
            min_max_scaler=min_max_scaler,
        )

    @property
    def class_names(self) -> list[str]:
        """Return the human-readable class labels "0" through "9"."""
        return list(map(str, range(10)))