Source code for kosmos.ml.datasets.mnist_dataset
import gzip
from importlib.resources import as_file, files
import numpy as np
from kosmos.ml.datasets.dataset import SLDataset
[docs]
class MNISTDataset(SLDataset):
    """Supervised-learning dataset of MNIST handwritten digits (10 classes).

    The samples are read from the gzip-compressed, comma-separated resource
    ``mnist.data.gz`` shipped inside the ``kosmos.ml.datasets.data`` package.
    Every row holds the feature values followed by the class label in the
    last column.

    Notes:
        - Number of instances: 70,000 (60,000 train + 10,000 test)
        - Number of features: 784 numeric (28x28 pixel images, flattened)
        - Classes: 10 (digits 0-9, roughly balanced)

    References:
        - OpenML — MNIST dataset: https://www.openml.org/d/554

    """

    def __init__(self, *, min_max_scaler: bool = True) -> None:
        """Load the bundled MNIST table and hand it to the base dataset.

        Args:
            min_max_scaler (bool): Whether to apply min-max scaling to the
                features. Forwarded unchanged to the parent initializer.
                Defaults to True.

        """
        resource = files("kosmos.ml.datasets.data") / "mnist.data.gz"
        # as_file yields a real filesystem path even when the package
        # resource is not stored as a plain file on disk.
        with as_file(resource) as archive_path:
            with gzip.open(archive_path, mode="r") as stream:
                # The stream yields bytes; decode each row to text lazily
                # before np.loadtxt parses it.
                text_rows = (raw.decode("utf-8") for raw in stream)
                table = np.loadtxt(text_rows, delimiter=",")

        # All columns but the last are features; the last one is the label.
        features = table[:, :-1].astype(np.float32, copy=False)
        labels = table[:, -1].astype(np.int64, copy=False)
        super().__init__(features, labels, min_max_scaler=min_max_scaler)

    @property
    def class_names(self) -> list[str]:
        """Return the ten digit labels "0" through "9" as strings."""
        return list(map(str, range(10)))