Source code for fedsim.distributed.data_management.basic_data_manager

r"""
A Basic Data Manager
--------------------
"""
import numpy as np
import torchvision
from sklearn.model_selection import train_test_split
from torchvision.datasets import CIFAR10
from torchvision.datasets import CIFAR100
from torchvision.datasets import MNIST
from tqdm import tqdm

from .data_manager import DataManager


class BasicDataManager(DataManager):
    r"""A basic data manager for partitioning the data. Currently three
    rules of partitioning are supported:

    - iid: same label distribution among clients. sample_balance determines
      each client's sample quota, drawn from a lognormal distribution.
    - dir: a Dirichlet distribution with concentration parameter given by
      label_balance determines the label balance of each client.
      sample_balance determines each client's sample quota, drawn from a
      lognormal distribution.
    - exclusive: samples corresponding to each label are randomly split
      among k clients, where k = total_sample_size * label_balance.
      sample_balance determines the way this split happens (quota).
      This rule is also known as "shards splitting".

    Args:
        root (str): root dir of the dataset to partition
        dataset (str): name of the dataset
        num_partitions (int): number of partitions or clients
        rule (str): rule of partitioning
        sample_balance (float): balance of the number of samples among clients
        label_balance (float): balance of the labels on each client
        local_test_portion (float): portion of the local test set taken from
            the train split
        global_valid_portion (float): portion of the global valid split.
            What remains from the global samples goes to the test split.
        seed (int): random seed of partitioning
        save_dir (str, optional): dir to save partitioned indices
    """

    def __init__(
        self,
        root="data",
        dataset="mnist",
        num_partitions=500,
        rule="iid",
        sample_balance=0.0,
        label_balance=1.0,
        local_test_portion=0.0,
        global_valid_portion=0.0,
        seed=10,
        save_dir="partitions",
    ):
        self.dataset_name = dataset
        self.num_partitions = num_partitions
        self.rule = rule
        self.sample_balance = sample_balance
        self.label_balance = label_balance
        self.local_test_portion = local_test_portion
        self.global_valid_portion = global_valid_portion
        # super should be called at the end because the abstract methods are
        # called in its __init__
        super(BasicDataManager, self).__init__(
            root,
            seed,
            save_dir=save_dir,
        )
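
    # A minimal usage sketch (the argument values here are hypothetical and
    # for illustration only). Constructing the manager is enough to trigger
    # the partitioning, since the abstract methods below are called from the
    # parent's __init__:
    #
    #   dm = BasicDataManager(
    #       root="data",
    #       dataset="cifar10",
    #       num_partitions=100,
    #       rule="dir",
    #       label_balance=0.3,
    #       sample_balance=0.5,
    #   )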

    def make_datasets(self, root):
        """makes and returns local and global dataset objects.

        The created datasets do not need a transform, as the datasets are
        recompiled on the fly with the separately provided transforms.

        Args:
            root (str): directory to download and manipulate data.

        Returns:
            Tuple[object, object]: local and global dataset
        """
        if self.dataset_name == "mnist":
            local_dset = MNIST(root, download=True, train=True, transform=None)
            # the global split is drawn from the held-out test set
            global_dset = MNIST(root, download=True, train=False, transform=None)
        elif self.dataset_name == "cifar10" or self.dataset_name == "cifar100":
            dst_class = CIFAR10 if self.dataset_name == "cifar10" else CIFAR100
            local_dset = dst_class(
                root=root, download=True, train=True, transform=None
            )
            global_dset = dst_class(
                root=root,
                download=True,
                train=False,
                transform=None,
            )
        else:
            raise NotImplementedError
        return local_dset, global_dset
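
    # For illustration (dm as in the sketch above, but with the default
    # "mnist" setting): the call downloads MNIST if needed and returns the
    # 60k-sample train split as the local dataset and the 10k-sample test
    # split as the global one.
    #
    #   local_dset, global_dset = dm.make_datasets("data")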

    def make_transforms(self):
        """makes and returns the dataset transformations for the local and
        global splits.

        Returns:
            Tuple[Dict[str, Callable], Dict[str, Callable]]: tuple of two
                dictionaries, first the local transform mapping and second
                the global transform mapping.
        """
        if self.dataset_name == "mnist":
            train_transform = torchvision.transforms.Compose(
                [
                    torchvision.transforms.ToTensor(),
                ]
            )
            infer_transform = train_transform
        elif self.dataset_name == "cifar10" or self.dataset_name == "cifar100":
            train_transform = torchvision.transforms.Compose(
                [
                    torchvision.transforms.ToTensor(),
                    torchvision.transforms.RandomCrop(24),
                    torchvision.transforms.RandomHorizontalFlip(),
                    torchvision.transforms.ColorJitter(
                        brightness=(0.5, 1.5), contrast=(0.5, 1.5)
                    ),
                ]
            )
            infer_transform = torchvision.transforms.Compose(
                [
                    torchvision.transforms.ToTensor(),
                    torchvision.transforms.CenterCrop(24),
                ]
            )
        return (
            dict(train=train_transform, test=infer_transform),  # for local
            dict(test=infer_transform, valid=infer_transform),  # for global
        )
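
    # For illustration (hypothetical variable names): the returned mappings
    # are plain dicts of torchvision Compose objects, so a sample can be
    # transformed directly, e.g.
    #
    #   local_trs, global_trs = dm.make_transforms()
    #   tensor_image = local_trs["train"](pil_image)  # PIL image -> tensor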

    def partition_local_data(self, dataset):
        """partitions local data indices into a client-indexed Iterable.

        Args:
            dataset (object): local dataset

        Returns:
            Dict[str, Iterable[Iterable[int]]]: dictionary of
                {split: client-indexed iterables of example indices}.
        """
        n = self.num_partitions
        targets = np.array(dataset.targets)
        all_sample_count = len(targets)
        num_classes = len(np.unique(targets))

        # the special case of the exclusive rule:
        if self.rule == "exclusive":
            # TODO: implement this
            raise NotImplementedError
            # # get number of samples per label
            # label_counts = [(targets == i).sum() for i in range(num_classes)]
            # for label, label_count in enumerate(label_counts):
            #     # randomly select k clients
            #     # determine the quota for each client from a lognorm
            #     # reassign the
        # *********************************************************
        # determine the sample quota for each client
        sample_per_client = all_sample_count // n
        if self.sample_balance != 0:
            # draw quotas from a lognormal distribution
            client_quota = np.random.lognormal(
                mean=np.log(sample_per_client),
                sigma=self.sample_balance,
                size=n,
            )
            client_quota = (
                client_quota / client_quota.sum() * all_sample_count
            ).astype(int)
            # integer rounding leaves a surplus/deficit; settle the
            # difference on the first client whose quota can absorb it
            diff = client_quota.sum() - all_sample_count
            if diff != 0:
                for clnt_i in range(n):
                    if client_quota[clnt_i] > diff:
                        client_quota[clnt_i] -= diff
                        break
        else:
            client_quota = np.ones(n, dtype=int) * sample_per_client

        indices = [
            np.zeros(client_quota[client], dtype=int) for client in range(n)
        ]
        # *********************************************************
        if self.rule == "dir":
            # Dirichlet partitioning rule
            cls_priors = np.random.dirichlet(
                alpha=[self.label_balance] * num_classes, size=n
            )
            prior_cumsum = np.cumsum(cls_priors, axis=1)
            idx_list = [np.where(targets == i)[0] for i in range(num_classes)]
            cls_amount = np.array([len(idx_list[i]) for i in range(num_classes)])

            print("partitioning")
            pbar = tqdm(total=np.sum(client_quota))
            while np.sum(client_quota) != 0:
                curr_clnt = np.random.randint(n)
                # if the current client is full, resample a client
                if client_quota[curr_clnt] <= 0:
                    continue
                client_quota[curr_clnt] -= 1
                curr_prior = prior_cumsum[curr_clnt]
                # exclude the classes that have run out of examples
                curr_prior[cls_amount <= 0] = -1
                # scale the prior up so the positive values sum to 1 again
                cpp = curr_prior[curr_prior > 0]
                cpp /= cpp.sum()
                curr_prior[curr_prior > 0] = cpp
                while True:
                    if (curr_prior > 0).sum() < 1:
                        raise Exception("choose another seed")
                    if (curr_prior > 0).sum() == 1:
                        cls_label = curr_prior.argmax()
                    else:
                        uu = np.random.uniform()
                        cls_label = np.argmax(uu <= curr_prior)
                    # redraw the class label if that class has run out
                    if cls_amount[cls_label] <= 0:
                        continue
                    cls_amount[cls_label] -= 1
                    indices[curr_clnt][client_quota[curr_clnt]] = idx_list[
                        cls_label
                    ][cls_amount[cls_label]]
                    break
                pbar.update(1)
            pbar.close()
        # *********************************************************
        elif self.rule == "iid":
            clnt_quota_cum_sum = np.concatenate(([0], np.cumsum(client_quota)))
            for client_index in range(n):
                indices[client_index] = np.arange(
                    clnt_quota_cum_sum[client_index],
                    clnt_quota_cum_sum[client_index + 1],
                )
        else:
            raise NotImplementedError

        ts_portion = self.local_test_portion
        if ts_portion > 0:
            new_indices = dict(train=[], test=[])
            for client_indices in indices:
                train_idxs, test_idxs = train_test_split(
                    client_indices, test_size=ts_portion
                )
                new_indices["train"].append(train_idxs)
                new_indices["test"].append(test_idxs)
        else:
            new_indices = dict(train=indices)
        return new_indices
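
    # Standalone sketch of the "dir" rule's prior sampling (not called by
    # this module): smaller concentration parameters yield near one-hot
    # rows, i.e. clients dominated by a single label, while larger ones
    # approach a uniform label distribution.
    #
    #   import numpy as np
    #   priors = np.random.dirichlet(alpha=[0.1] * 10, size=3)
    #   print(priors.round(2))  # 3 clients x 10 classes; each row sums to 1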

    def partition_global_data(self, dataset):
        """partitions global data indices into splits (e.g., train, test, ...).

        Args:
            dataset (object): global dataset

        Returns:
            Dict[str, Iterable[int]]: dictionary of
                {split: example indices of the global dataset}.
        """
        num = len(dataset)
        if self.global_valid_portion > 0:
            val = int(num * self.global_valid_portion)
            return dict(test=range(val, num), valid=range(0, val))
        return dict(test=range(num))
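
    # Worked example: with len(dataset) == 10000 and
    # global_valid_portion == 0.1, the method above returns
    # dict(test=range(1000, 10000), valid=range(0, 1000)).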

    def get_identifiers(self):
        """Returns identifiers to be used for saving the partition info.

        Returns:
            Sequence[str]: a sequence of str identifying the class instance
        """
        identifiers = [
            self.dataset_name,
            str(self.num_partitions),
            self.rule,
        ]
        if self.rule == "dir":
            identifiers.append(str(self.label_balance))
        if self.sample_balance == 0:
            identifiers.append("balanced")
        else:
            identifiers.append(f"unbalanced_{self.sample_balance}")
        if self.local_test_portion > 0:
            identifiers.append("lTS_{}".format(self.local_test_portion))
        if self.global_valid_portion > 0:
            identifiers.append("gVL_{}".format(self.global_valid_portion))
        return identifiers
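
# Worked example (values assumed for illustration): with the defaults above
# plus sample_balance=0.5 and local_test_portion=0.2, get_identifiers()
# returns ["mnist", "500", "iid", "unbalanced_0.5", "lTS_0.2"], which is
# presumably used to name the saved partition indices under save_dir.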