Source code for fedsim.distributed.data_management.utils

"""
Data Management Utils
---------------------
"""

import numpy as np
from torch.utils import data


[docs]class Subset(data.Dataset): r"""Subset of a dataset at specified indices. Args: dataset (Dataset): The whole Dataset indices (sequence): Indices in the whole set selected for subset. """ def __init__(self, dataset, indices, transform=None): self.dataset = dataset if isinstance(indices, int) and indices == -1: self.indices = range(len(dataset)) else: self.indices = indices self.transform = transform targets = np.array(dataset.targets) self.targets = targets[self.indices] # remove the transform function of the original dataset if transform # is provided avoiding double transform if transform is not None and self.dataset.transform is not None: self.dataset.transform = None def __getitem__(self, idx): if isinstance(idx, list): x, y = self.dataset[[self.indices[i] for i in idx]] if self.transform is None: return x, y return self.transform(x), y x, y = self.dataset[self.indices[idx]] if self.transform is None: return x, y return self.transform(x), y def __len__(self): return len(self.indices)