Source code for kosmos.ml.dataloader

from torch.utils.data import DataLoader, random_split

from kosmos.ml.datasets.dataset import SLDataset


def make_train_test_dataloaders(
    dataset: SLDataset,
    train_split: float,
    batch_size: int,
    num_train_subsets: int = 1,
) -> tuple[list[DataLoader], DataLoader]:
    """Split a dataset into training and test subsets and wrap them in DataLoaders.

    The dataset is first split randomly into a training and a test subset. The
    training subset is then split randomly into ``num_train_subsets`` partitions
    whose sizes differ by at most one sample. The training loaders shuffle their
    subset each epoch; the test loader does not.

    Args:
        dataset: The dataset to be split.
        train_split: Fraction of dataset for training, between 0 and 1. The
            number of training samples is ``int(len(dataset) * train_split)``;
            the remaining samples form the test subset.
        batch_size: Number of samples per batch in all loaders.
        num_train_subsets: Number of partitions to split the training subset
            into. Defaults to 1.

    Returns:
        tuple[list[DataLoader], DataLoader]: A tuple containing:
            - list[DataLoader]: DataLoaders for the partitions of the training subset.
            - DataLoader: DataLoader for the test subset.

    Raises:
        ValueError: If ``num_train_subsets`` is less than 1, if ``train_split``
            lies outside ``[0, 1]``, if either the training or the test subset
            would be empty, or if there are fewer training samples than
            requested training subsets.

    """
    if num_train_subsets < 1:
        msg = "num_train_subsets must be >= 1."
        raise ValueError(msg)

    # An out-of-range fraction yields a negative length for one subset. The two
    # lengths still sum to len(dataset), so random_split would not reject them
    # and would silently slice out a wrong split. Reject it explicitly instead.
    # (NaN also fails this chained comparison and is rejected here.)
    if not 0.0 <= train_split <= 1.0:
        msg = f"train_split must be between 0 and 1, got {train_split}."
        raise ValueError(msg)

    n_total = len(dataset)
    n_train = int(n_total * train_split)
    n_test = n_total - n_train
    if n_train == 0 or n_test == 0:
        msg = "Empty subset after splitting."
        raise ValueError(msg)

    # Train/test split
    train_subset, test_subset = random_split(dataset, lengths=[n_train, n_test])

    # Split the train subset across num_train_subsets partitions: every
    # partition gets `base` samples and the first `rem` partitions get one extra.
    base, rem = divmod(n_train, num_train_subsets)
    if base == 0:
        # Fewer samples than partitions: at least one partition would be empty.
        msg = (
            f"Not enough training samples ({n_train})"
            f" to create {num_train_subsets} training subsets."
        )
        raise ValueError(msg)
    sizes = [base + 1 if i < rem else base for i in range(num_train_subsets)]

    client_subsets = random_split(train_subset, lengths=sizes)
    train_loaders = [
        DataLoader(cs, batch_size=batch_size, shuffle=True) for cs in client_subsets
    ]
    test_loader = DataLoader(test_subset, batch_size=batch_size, shuffle=False)
    return train_loaders, test_loader