# Source code for kosmos.ml.datasets.mnist_dataset

import gzip
from importlib.resources import as_file, files

import numpy as np

from kosmos.ml.datasets.dataset import SLDataset


class MNISTDataset(SLDataset):
    """MNIST handwritten digits, exposed as a multiclass classification dataset.

    Notes:
        - Number of instances: 70,000 (60,000 train + 10,000 test)
        - Number of features: 784 numeric (28x28 pixel images, flattened)
        - Classes: 10 (digits 0-9, roughly balanced)

    References:
        - OpenML — MNIST dataset: https://www.openml.org/d/554
    """

    def __init__(self, *, min_max_scaler: bool = True) -> None:
        """Load the bundled MNIST table and hand it to the base dataset.

        The data is read from the gzip-compressed, comma-separated resource
        ``mnist.data.gz`` in the ``kosmos.ml.datasets.data`` package. The last
        column is taken as the label and every other column as a feature.

        Args:
            min_max_scaler (bool): Whether to apply min-max scaling to the
                features. Forwarded unchanged to the base class. Defaults to
                True.
        """
        resource = files("kosmos.ml.datasets.data") / "mnist.data.gz"

        # as_file() yields a real filesystem path for the packaged resource,
        # which is what gzip.open() is given here.
        with as_file(resource) as archive_path:
            with gzip.open(archive_path, mode="r") as raw_stream:
                # mode="r" yields bytes lines; decode them lazily so
                # np.loadtxt receives text rows one at a time.
                text_rows = (raw_line.decode("utf-8") for raw_line in raw_stream)
                table = np.loadtxt(text_rows, delimiter=",")

        # Last column -> integer labels; all preceding columns -> float32 features.
        labels = table[:, -1].astype(np.int64, copy=False)
        features = table[:, :-1].astype(np.float32, copy=False)

        super().__init__(features, labels, min_max_scaler=min_max_scaler)

    @property
    def class_names(self) -> list[str]:
        """Return the human-readable class labels: the digits "0" through "9"."""
        return list(map(str, range(10)))