Data Manager

class DataManager(root, seed, save_dir=None)

Base class for data managers. Every other data manager inherits from this class. Child classes must implement four abstract methods: get_identifiers, make_datasets, make_transforms, and partition_local_data.

Warning

When inheriting, call super().__init__() at the end of the child constructor, because the base constructor calls the abstract methods.

Parameters
  • root (str) -- root dir of the dataset to partition

  • seed (int) -- random seed of partitioning

  • save_dir (str, optional) -- path to save partitioned indices.
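The warning above can be illustrated with a minimal sketch. `BaseManager` is a hypothetical stand-in, not the library's DataManager. It only mimics the one behaviour the warning describes (abstract methods being called inside the base constructor), and `num_partitions` is an invented attribute used for illustration.

```python
class BaseManager:
    """Hypothetical stand-in for a base class that calls abstract hooks in __init__."""

    def __init__(self, root, seed, save_dir=None):
        self.root = root
        self.seed = seed
        self.save_dir = save_dir
        # The hook runs here, so it already needs the child's attributes.
        self.identifiers = self.get_identifiers()

    def get_identifiers(self):
        raise NotImplementedError


class MyManager(BaseManager):
    def __init__(self, root, seed, num_partitions, save_dir=None):
        # Set the child's own attributes first ...
        self.num_partitions = num_partitions
        # ... and call super last, because it invokes get_identifiers().
        super().__init__(root, seed, save_dir=save_dir)

    def get_identifiers(self):
        return ["my_manager", str(self.num_partitions)]


manager = MyManager("./data", seed=10, num_partitions=4)
print(manager.identifiers)
```

In this sketch, swapping the two statements in `MyManager.__init__` would raise an AttributeError, since `get_identifiers` would read `num_partitions` before it is set.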

get_global_dataset() -> Dict[str, torch.utils.data.dataset.Dataset]

returns the global dataset

Returns

Dict[str, Dataset] -- global dataset for each split

get_global_splits_names()

returns the names of the global splits (train, test, etc.)

Returns

List[str] -- list of global split names

get_group_dataset(ids: Iterable[int]) -> Dict[str, torch.utils.data.dataset.Dataset]

returns the local dataset corresponding to a group of given partition ids

Parameters

ids (Iterable[int]) -- a list or tuple of partition ids

Returns

Dict[str, Dataset] -- a mapping of split_name: dataset
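One way to picture the returned mapping is to merge, split by split, the index lists of the requested partitions. The sketch below assumes partitions are stored as plain index lists. `local_indices` and `group_indices` are hypothetical names, and the library's internal representation may differ.

```python
from typing import Dict, Iterable, List

# Hypothetical partition table: split name -> one index list per partition id.
local_indices: Dict[str, List[List[int]]] = {
    "train": [[0, 1], [2, 3], [4, 5]],
    "test": [[6], [7], [8]],
}


def group_indices(ids: Iterable[int]) -> Dict[str, List[int]]:
    """Merge the index lists of the given partition ids, one result per split."""
    ids = list(ids)  # materialize once, since ids is reused for every split
    return {
        split: [i for pid in ids for i in parts[pid]]
        for split, parts in local_indices.items()
    }


group = group_indices([0, 2])
print(group)
```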

get_identifiers() -> Sequence[str]

Returns identifiers to be used for saving the partition info.

Raises

NotImplementedError -- this abstract method should be implemented by child classes

Returns

Sequence[str] -- a sequence of str identifying the class instance

get_local_dataset(id: int) -> Dict[str, torch.utils.data.dataset.Dataset]

returns the local dataset corresponding to a given partition id

Parameters

id (int) -- partition id

Returns

Dict[str, Dataset] -- a mapping of split_name: dataset

get_local_splits_names()

returns the names of the local splits (train, test, etc.)

Returns

List[str] -- list of local split names

get_oracle_dataset() -> Dict[str, torch.utils.data.dataset.Dataset]

returns all of the local datasets stacked up.

Returns

Dict[str, Dataset] -- Oracle dataset for each split

get_partitioning_name() -> str

returns a unique name for the DataManager instance.

Note

This method helps with storing and retrieving the partitioning indices, so that experiments can be reproduced on a machine.

Returns

str -- a unique name for the DataManager instance.
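As a rough illustration of how such a name could support saving and reloading partition indices, the sketch below joins the identifiers with the seed and uses the result as a file name. The naming scheme, the `partitioning_name` helper, the `.json` extension and the `partitions` directory are all assumptions made for this example, not the library's actual format.

```python
import os
from typing import Sequence


def partitioning_name(identifiers: Sequence[str], seed: int) -> str:
    """Hypothetical naming scheme: identifiers joined with the seed."""
    return "_".join([*identifiers, f"seed{seed}"])


name = partitioning_name(["my_dataset", "iid", "100"], seed=10)
# The same identifiers and seed always map to the same file.
path = os.path.join("partitions", name + ".json")
print(name)
```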

make_datasets(root: str) -> Tuple[object, object]

makes and returns the local and global dataset objects. The created datasets do not need a transform, because the datasets are recompiled on the fly with separately provided transforms.

Parameters
  • dataset_name (str) -- name of the dataset.

  • root (str) -- directory to download and manipulate data.

Raises

NotImplementedError -- this abstract method should be implemented by child classes

Returns

Tuple[object, object] -- local and global dataset

make_transforms() -> Tuple[object, object]

makes and returns the dataset transformations for the local and global splits.

Raises

NotImplementedError -- this abstract method should be implemented by child classes

Returns

Tuple[Dict[str, Callable], Dict[str, Callable]] -- tuple of two dictionaries: first the local transform mapping, second the global transform mapping.
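The return shape described above can be sketched with plain callables. The split names and the `scale` function are illustrative assumptions. A real subclass would typically return image or tensor transforms instead.

```python
from typing import Callable, Dict, Tuple


def scale(x: float) -> float:
    """Toy transform: map a pixel value in [0, 255] to [0, 1]."""
    return x / 255.0


def make_transforms() -> Tuple[Dict[str, Callable], Dict[str, Callable]]:
    # One callable per split name; local mapping first, global mapping second.
    local_transforms = {"train": scale, "test": scale}
    global_transforms = {"test": scale}
    return local_transforms, global_transforms


local_t, global_t = make_transforms()
print(local_t["train"](255))
```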

partition_global_data(dataset: object) -> Dict[str, Iterable[int]]

partitions global data indices into splits (e.g., train, test, ...).

Parameters

dataset (object) -- global dataset

Returns

Dict[str, Iterable[int]] -- dictionary of {split:example indices of global dataset}.
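A minimal sketch of the return shape, assuming the whole global dataset is used as a single "test" split. That choice is an assumption for illustration only and is not confirmed as the library's default behaviour.

```python
from typing import Dict, Sequence


def partition_global_data(dataset: Sequence) -> Dict[str, range]:
    # Assumption: every example of the global dataset goes to one "test" split.
    return {"test": range(len(dataset))}


splits = partition_global_data(["a", "b", "c", "d"])
print(list(splits["test"]))
```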

partition_local_data(dataset: object) -> Dict[str, Iterable[Iterable[int]]]

partitions local data indices into a client-indexed iterable for each split.

Parameters

dataset (object) -- local dataset

Raises

NotImplementedError -- this abstract method should be implemented by child classes

Returns

Dict[str, Iterable[Iterable[int]]] -- dictionary of {split:client-indexed iterables of example indices}.
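As one possible implementation, the sketch below performs a seeded IID partition: indices are shuffled reproducibly and dealt round-robin to clients. Here `num_partitions` and `seed` are passed as arguments to keep the example self-contained. In a subclass they would more likely be instance attributes, and the single "train" split is an assumption.

```python
import random
from typing import Dict, List, Sequence


def partition_local_data(
    dataset: Sequence, num_partitions: int, seed: int
) -> Dict[str, List[List[int]]]:
    indices = list(range(len(dataset)))
    # A dedicated Random instance keeps the shuffle reproducible for a given seed.
    random.Random(seed).shuffle(indices)
    # Deal the shuffled indices round-robin, one list per client.
    clients = [indices[i::num_partitions] for i in range(num_partitions)]
    return {"train": clients}


parts = partition_local_data(list(range(10)), num_partitions=3, seed=0)
print([len(client) for client in parts["train"]])
```

Position k of each split's list holds the example indices of client k, which matches the client-indexed layout described above.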