Source code for fedsim.distributed.centralized.centralized_fl_algorithm

r"""
Centralized Federated Learning Algorithm
----------------------------------------
"""
import inspect
import random
from functools import partial
from typing import Any
from typing import Callable
from typing import Dict
from typing import Hashable
from typing import Iterable
from typing import Mapping
from typing import Optional
from typing import Union

import torch
from torch import nn
from torch.utils.data import DataLoader
from tqdm import trange

from fedsim import scores
from fedsim.utils import AppendixAggregator
from fedsim.utils import SerialAggregator
from fedsim.utils import Storage
from fedsim.utils import apply_on_dict
from fedsim.utils import get_from_module


class CentralFLAlgorithm(object):
    r"""Base class for centralized FL algorithms.

    Args:
        data_manager (``distributed.data_management.DataManager``): data manager
        metric_logger (``logall.Logger``): metric logger for tracking.
        num_clients (int): number of clients
        sample_scheme (``str``): mode of sampling clients. Options are
            ``'uniform'`` and ``'sequential'``
        sample_rate (``float``): rate of sampling clients
        model_def (``torch.nn.Module``): definition for constructing the model
        epochs (``int``): number of local epochs
        criterion_def (``Callable``): loss function defining the local objective
        optimizer_def (``Callable``): definition of the server optimizer
        local_optimizer_def (``Callable``): definition of the local optimizer
        lr_scheduler_def (``Callable``): definition of the lr scheduler of the
            server optimizer.
        local_lr_scheduler_def (``Callable``): definition of the lr scheduler of
            the local optimizer
        r2r_local_lr_scheduler_def (``Callable``): definition of a scheduler for
            the lr that is delivered to the clients at each round (determines the
            initial lr of the client optimizer)
        batch_size (int): batch size of the local training
        test_batch_size (int): inference-time batch size
        device (str): cpu, cuda, or gpu number

    .. note::
        definition of

        * learning rate schedulers, could be any of the ones defined at
          ``torch.optim.lr_scheduler`` or any other that implements the step and
          get_last_lr methods.
        * optimizers, could be any ``torch.optim.Optimizer``.
        * model, could be any ``torch.nn.Module``.
        * criterion, could be any of ``fedsim.scores``.

    Architecture:

        .. image:: ../_static/arch.svg
    """

    def __init__(
        self,
        data_manager,
        metric_logger,
        num_clients,
        sample_scheme,
        sample_rate,
        model_def,
        epochs,
        criterion_def,
        optimizer_def=partial(torch.optim.SGD, lr=1.0),
        local_optimizer_def=partial(torch.optim.SGD, lr=0.1),
        lr_scheduler_def=None,
        local_lr_scheduler_def=None,
        r2r_local_lr_scheduler_def=None,
        batch_size=32,
        test_batch_size=64,
        device="cpu",
        *args,
        **kwargs,
    ):
        sample_count = int(sample_rate * num_clients)
        if not 1 <= sample_count <= num_clients:
            raise Exception(
                "invalid client sample size for {}% of {} clients".format(
                    sample_rate, num_clients
                )
            )
        # support functools.partial as model_def
        if hasattr(model_def, "func"):
            model_def_ = getattr(model_def, "func")
        else:
            model_def_ = model_def

        if isinstance(model_def_, str):
            model_def = get_from_module("fedsim.models", model_def)
        elif issubclass(model_def_, nn.Module):
            model_def = model_def
        else:
            raise Exception("incompatible model!")

        if isinstance(criterion_def, str) and hasattr(scores, criterion_def):
            criterion_def = getattr(scores, criterion_def)
        else:
            criterion_def = criterion_def

        if r2r_local_lr_scheduler_def is not None:
            # get the local lr to build the r2r scheduler
            # if partial is used
            if hasattr(local_optimizer_def, "keywords"):
                clr = local_optimizer_def.keywords["lr"]
            # if lr is an argument
            elif "lr" in inspect.signature(local_optimizer_def).parameters.keys():
                clr = inspect.signature(local_optimizer_def).parameters["lr"].default
            else:
                raise Exception("lr not found in local optimizer class")
            # make a dummy optimizer so that pytorch lr schedulers can be used
            # directly
            dummy_params = [
                torch.tensor([1.0, 1.0], requires_grad=True),
            ]
            dummy_optimizer = torch.optim.SGD(params=dummy_params, lr=clr)
            r2r_scheduler = r2r_local_lr_scheduler_def(dummy_optimizer)

            def last_private_lr(sch):
                return sch._last_lr

            if not hasattr(r2r_scheduler, "get_last_lr"):
                r2r_scheduler.get_last_lr = partial(last_private_lr, r2r_scheduler)
            r2r_local_lr_scheduler = r2r_scheduler
            dummy_optimizer.step()
        else:
            r2r_local_lr_scheduler = None

        global_dataloaders = {
            key: DataLoader(
                dataset,
                batch_size=test_batch_size,
                pin_memory=True,
            )
            for key, dataset in data_manager.get_global_dataset().items()
        }
        oracle_dataset = data_manager.get_oracle_dataset()

        # server storage
        self._server_memory = Storage()
        # initial storage writes
        # read- and write-protected entries
        self._server_memory.write(
            "data_manager",
            data_manager,
            read_protected=True,
            write_protected=True,
        )
        self._server_memory.write(
            "num_clients",
            num_clients,
            read_protected=True,
            write_protected=True,
        )
        self._server_memory.write(
            "sample_count",
            sample_count,
            read_protected=True,
            write_protected=True,
        )
        self._server_memory.write(
            "sample_scheme",
            sample_scheme,
            read_protected=True,
            write_protected=True,
        )
        self._server_memory.write(
            "oracle_dataset",
            oracle_dataset,
            read_protected=True,
            write_protected=True,
        )
        self._server_memory.write(
            "last_client_sampled",
            None,
            read_protected=True,
            write_protected=True,
        )
        # write-protected entries
        self._server_memory.write("test_batch_size", test_batch_size, write_protected=True)
        self._server_memory.write("optimizer_def", optimizer_def, write_protected=True)
        self._server_memory.write("lr_scheduler_def", lr_scheduler_def, write_protected=True)
        self._server_memory.write(
            "r2r_local_lr_scheduler",
            r2r_local_lr_scheduler,
            write_protected=True,
        )
        self._server_memory.write("criterion_def", criterion_def, write_protected=True)
        self._server_memory.write("model_def", model_def, write_protected=True)
        self._server_memory.write("metric_logger", metric_logger, write_protected=True)
        self._server_memory.write("device", device, write_protected=True)
        self._server_memory.write(
            "global_dataloaders",
            global_dataloaders,
            write_protected=True,
        )
        self._server_memory.write("rounds", 0, write_protected=True)

        # client storage
        self._client_memory = {k: Storage() for k in range(num_clients)}

        # private client memory
        self._local_cfg = Storage()
        self._local_cfg.write("epochs", epochs, write_protected=True)
        self._local_cfg.write("batch_size", batch_size, write_protected=True)
        self._local_cfg.write("test_batch_size", test_batch_size, write_protected=True)
        self._local_cfg.write("criterion_def", criterion_def, write_protected=True)
        self._local_cfg.write(
            "local_optimizer_def",
            local_optimizer_def,
            write_protected=True,
        )
        self._local_cfg.write(
            "local_lr_scheduler_def",
            local_lr_scheduler_def,
            write_protected=True,
        )
        self._local_cfg.write("device", device)

        # this is overwritten in the train method
        self._train_split_name = "train"

        self._server_scores = {key: dict() for key in global_dataloaders}
        self._client_scores = {
            key: dict() for key in data_manager.get_local_splits_names()
        }

        self.user_methods = dict(
            init=self.__class__.init,
            at_round_start=self.__class__.at_round_start,
            at_round_end=self.__class__.at_round_end,
            deploy=self.__class__.deploy,
            optimize=self.__class__.optimize,
            report=self.__class__.report,
            receive_from_client=self.__class__.receive_from_client,
            send_to_client=self.__class__.send_to_client,
            send_to_server=self.__class__.send_to_server,
        )
        for key, value in self.user_methods.items():
            sig = inspect.signature(value)
            if "self" in sig.parameters.keys():
                raise Exception(
                    f"Remove `self` from {key} arguments. "
                    "All user methods should be static!"
                )

        self.user_methods["init"](self._server_memory, *args, **kwargs)

    def _sample_clients(self):
        sample_scheme = self._server_memory.read("sample_scheme", silent=True)
        sample_count = self._server_memory.read("sample_count", silent=True)
        num_clients = self._server_memory.read("num_clients", silent=True)
        last_client_sampled = self._server_memory.read(
            "last_client_sampled", silent=True
        )
        if sample_scheme == "uniform":
            clients = random.sample(range(num_clients), sample_count)
        elif sample_scheme == "sequential":
            last_sampled = -1 if last_client_sampled is None else last_client_sampled
            clients = [
                (i + 1) % num_clients
                for i in range(last_sampled, last_sampled + sample_count)
            ]
            self._server_memory.write("last_client_sampled", clients[-1], silent=True)
        else:
            raise NotImplementedError
        return clients

    def _send_to_client(self, client_id):
        return self.user_methods["send_to_client"](
            self._server_memory, client_id=client_id
        )

    def _send_to_server(self, client_id):
        r2r_local_lr_scheduler = self._server_memory.read("r2r_local_lr_scheduler")
        data_manager = self._server_memory.read("data_manager", silent=True)
        rounds = self._server_memory.read("rounds")
        epochs = self._local_cfg.read("epochs")
        batch_size = self._local_cfg.read("batch_size")
        test_batch_size = self._local_cfg.read("test_batch_size")
        local_optimizer_def = self._local_cfg.read("local_optimizer_def")
        criterion_def = self._local_cfg.read("criterion_def")
        local_lr_scheduler_def = self._local_cfg.read("local_lr_scheduler_def")
        device = self._local_cfg.read("device")

        if r2r_local_lr_scheduler is None:
            local_optimizer_def = local_optimizer_def
        else:
            local_optimizer_def = partial(
                local_optimizer_def,
                lr=r2r_local_lr_scheduler.get_last_lr()[0],
            )

        datasets = data_manager.get_local_dataset(client_id)
        round_scores = self.get_local_scores()
        storage = self._client_memory[client_id]
        train_split_name = self.get_train_split_name()
        client_ctx = self.user_methods["send_to_server"](
            client_id,
            rounds,
            storage,
            datasets,
            train_split_name,
            round_scores,
            epochs,
            criterion_def(),
            batch_size,
            test_batch_size,
            local_optimizer_def,
            local_lr_scheduler_def,
            device,
            ctx=self._send_to_client(client_id),
        )
        if not isinstance(client_ctx, dict):
            raise Exception("client should only return a dict!")
        return {**client_ctx, "client_id": client_id}

    def _receive_from_client(self, client_msg, serial_aggregator, appendix_aggregator):
        client_id = client_msg.pop("client_id")
        train_split_name = self.get_train_split_name()
        return self.user_methods["receive_from_client"](
            self._server_memory,
            client_id,
            client_msg,
            train_split_name,
            serial_aggregator,
            appendix_aggregator,
        )

    def _optimize(self, serial_aggregator, appendix_aggregator):
        reports = self.user_methods["optimize"](
            self._server_memory, serial_aggregator, appendix_aggregator
        )
        # purge aggregated results
        del serial_aggregator
        return reports

    def _report(self, round_scores, optimize_reports=None, deployment_points=None):
        global_dataloaders = self._server_memory.read("global_dataloaders", silent=True)
        metric_logger = self._server_memory.read("metric_logger")
        rounds = self._server_memory.read("rounds")
        device = self._server_memory.read("device")
        report_metrics = self.user_methods["report"](
            self._server_memory,
            global_dataloaders,
            rounds,
            round_scores,
            metric_logger,
            device,
            optimize_reports,
            deployment_points,
        )
        if metric_logger is not None:
            log_fn = metric_logger.log_scalar
            apply_on_dict(report_metrics, log_fn, step=rounds)
        return report_metrics

    def _train(self, rounds, num_score_report_point=None):
        diverged = False
        cur_round = self._server_memory.read("rounds")
        score_aggregator = AppendixAggregator(max_deque_lenght=num_score_report_point)
        for round_num in trange(rounds + 1):
            self._at_round_start()
            round_serial_aggregator = SerialAggregator()
            round_appendix_aggregator = AppendixAggregator()
            for client_id in self._sample_clients():
                client_msg = self._send_to_server(client_id)
                success = self._receive_from_client(
                    client_msg,
                    round_serial_aggregator,
                    round_appendix_aggregator,
                )
                # signal divergence
                if not success:
                    diverged = True
                    break
            # check for divergence, early return
            if diverged:
                return score_aggregator.pop_all()
            # optimize
            opt_reports = self._optimize(
                round_serial_aggregator, round_appendix_aggregator
            )
            deploy_point = self.user_methods["deploy"](self._server_memory)
            round_scores = self.get_global_scores()
            score_dict = self._report(round_scores, opt_reports, deploy_point)
            score_aggregator.append_all(score_dict, step=cur_round)
            self._at_round_end(score_aggregator)
            self._server_memory.write("rounds", cur_round + round_num + 1, silent=True)
        return score_aggregator.pop_all()

    def _at_round_start(self) -> None:
        self.user_methods["at_round_start"](self._server_memory)

    def _at_round_end(self, score_aggregator) -> None:
        r2r_local_lr_scheduler = self._server_memory.read("r2r_local_lr_scheduler")
        if r2r_local_lr_scheduler is not None:
            r2r_local_lr_scheduler.step()
        self.user_methods["at_round_end"](self._server_memory, score_aggregator)

    def _get_round_scores(self, score_def_deck):
        # filter out the scores that should not be present in the current round
        rounds = self._server_memory.read("rounds")
        round_scores = dict()
        for name, definition in score_def_deck.items():
            obj = definition()
            if rounds % obj.log_freq == 0:
                round_scores[name] = obj
        return round_scores

    # API functions
    def train(
        self,
        rounds: int,
        num_score_report_point: Optional[int] = None,
        train_split_name="train",
    ) -> Optional[Dict[str, Optional[float]]]:
        r"""Loops over the learning pipeline of the distributed algorithm for the
        given number of rounds.

        .. note::
            * The client metrics are reported in the form of
              ``clients.{metric_name}``.
            * The server metrics (score results) are reported in the form of
              ``server.{deployment_point}.{metric_name}``.

        Args:
            rounds (int): number of rounds to train.
            num_score_report_point (int): limits the number of points to return
                in the reports.
            train_split_name (str): local split name to perform training on.
                Defaults to 'train'.

        Returns:
            Optional[Dict[str, Optional[float]]]: collected score metrics.
        """
        # store the default split name
        default_split_name = self._train_split_name
        self._train_split_name = train_split_name
        ans = self._train(
            rounds=rounds,
            num_score_report_point=num_score_report_point,
        )
        # restore the default split name
        self._train_split_name = default_split_name
        return ans
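
    # Example (illustrative sketch, not part of the original source): training a
    # hypothetical concrete subclass ``MyFedAvg`` of this class. ``dm`` stands for
    # a ``DataManager`` instance and ``SimpleMLP`` is a placeholder model class.
    #
    #     alg = MyFedAvg(
    #         data_manager=dm,
    #         metric_logger=None,
    #         num_clients=100,
    #         sample_scheme="uniform",
    #         sample_rate=0.1,  # 10 clients sampled per round
    #         model_def=SimpleMLP,
    #         epochs=1,
    #         criterion_def=torch.nn.CrossEntropyLoss,
    #     )
    #     reports = alg.train(rounds=100)  # dict of collected score metrics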

    def get_model_def(self):
        """To get the definition of the model so that one can instantiate it by
        calling it.

        Returns:
            Callable: definition of the model. To instantiate, you may call the
                returned value with parentheses in front.
        """
        model_def = self._server_memory.read("model_def")
        return model_def

    def get_server_storage(self):
        """To access the public configs of the server.

        Returns:
            Storage: public server storage.
        """
        return self._server_memory

    def get_round_number(self):
        """To get the current round number, starting from zero.

        Returns:
            int: current round number, starting from zero.
        """
        return self._server_memory.read("rounds")

    def get_train_split_name(self):
        """To get the name of the split used to perform local training.

        Returns:
            Hashable: name of the split used for local training.
        """
        return self._train_split_name

    def get_global_loader_split(self, split_name) -> Iterable:
        """To get the data loader of a specific global split.

        Args:
            split_name (Hashable): split name.

        Returns:
            Iterable: data loader for the global split ``split_name``.
        """
        return self._server_memory.read("global_dataloaders")[split_name]

    def get_device(self) -> str:
        """To get the device name or number.

        Returns:
            str: device name or number.
        """
        return self._server_memory.read("device")

    def hook_local_score(self, score_def, score_name, split_name) -> None:
        """To hook a score measurement on local data.

        Args:
            score_def (Callable): definition of the score, used to make new
                instances of it. The list of existing scores can be found under
                ``fedsim.scores``.
            score_name (Hashable): name of the score to show up in the logs.
            split_name (Hashable): name of the data split to apply the
                measurement on.
        """
        self._client_scores[split_name][score_name] = score_def

    def hook_global_score(self, score_def, score_name, split_name) -> None:
        """To hook a score measurement on global data.

        Args:
            score_def (Callable): definition of the score, used to make new
                instances of it. The list of existing scores can be found under
                ``fedsim.scores``.
            score_name (Hashable): name of the score to show up in the logs.
            split_name (Hashable): name of the data split to apply the
                measurement on.
        """
        self._server_scores[split_name][score_name] = score_def
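
    # Example (illustrative sketch): log an accuracy score on the global "test"
    # split every 10 rounds. The score-class name ``Accuracy`` and its
    # ``log_freq`` argument are assumptions about ``fedsim.scores`` here.
    #
    #     from functools import partial
    #     from fedsim import scores
    #
    #     alg.hook_global_score(
    #         partial(scores.Accuracy, log_freq=10),
    #         score_name="accuracy",
    #         split_name="test",
    #     )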

    def get_local_split_scores(self, split_name) -> Dict[str, Any]:
        """To instantiate and get the local scores that have to be measured in
        the current round (log frequencies are matched) for a specific data
        split.

        Args:
            split_name (Hashable): name of the local data split.

        Returns:
            Dict[str, Any]: mapping of name:score.
        """
        return self._get_round_scores(self._client_scores[split_name])

    def get_global_split_scores(self, split_name) -> Dict[str, Any]:
        """To instantiate and get the global scores that have to be measured in
        the current round (log frequencies are matched) for a specific data
        split.

        Args:
            split_name (Hashable): name of the global data split.

        Returns:
            Dict[str, Any]: mapping of name:score. If no score is listed for the
                given split, None is returned.
        """
        split_scores = self._get_round_scores(self._server_scores[split_name])
        if len(split_scores) > 0:
            return split_scores
        return None

    def get_local_scores(self) -> Dict[str, Any]:
        """To instantiate and get the local scores that have to be measured in
        the current round (log frequencies are matched).

        Returns:
            Dict[str, Any]: mapping of split_name to a mapping of name:score.
        """
        scores = dict()
        for split_name in self._client_scores:
            split_score = self.get_local_split_scores(split_name)
            if split_score is not None:
                scores[split_name] = split_score
        return scores

    def get_global_scores(self) -> Dict[str, Any]:
        """To instantiate and get the global scores that have to be measured in
        the current round (log frequencies are matched).

        Returns:
            Dict[str, Any]: mapping of split_name to a mapping of name:score.
                Splits with no score due in the current round are omitted.
        """
        scores = dict()
        for split_name in self._server_scores:
            split_score = self.get_global_split_scores(split_name)
            if split_score is not None:
                scores[split_name] = split_score
        return scores

    # we do not do type hinting, however, the hints for the abstract
    # methods are provided to help clarity for users

    def init(server_storage: Storage, *args, **kwargs) -> None:
        """This method is executed only once, at the time of instantiating the
        algorithm object. Here you define your model and whatever else is needed
        during the training. Remember to write the outcome of your processing to
        ``server_storage`` for access in other methods.

        .. note::
            ``*args`` and ``**kwargs`` are directly passed through from the
            algorithm constructor.

        Args:
            server_storage (Storage): server storage object
        """
        pass
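
    # Example (illustrative sketch of a subclass ``init``; the storage key
    # "model" is a convention assumed here, not mandated by the base class):
    #
    #     def init(server_storage, *args, **kwargs):
    #         device = server_storage.read("device")
    #         model = server_storage.read("model_def")().to(device)
    #         server_storage.write("model", model)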

    # optional methods

    def at_round_start(server_storage: Storage) -> None:
        """To inject code at the beginning of rounds in the training loop.

        Args:
            server_storage (Storage): server storage object.
        """
        pass

    def at_round_end(
        server_storage: Storage,
        score_aggregator: AppendixAggregator,
    ) -> None:
        """To inject code at the end of rounds in the training loop.

        Args:
            server_storage (Storage): server storage object.
            score_aggregator (AppendixAggregator): contains the aggregated
                scores.
        """
        pass

    # abstract methods

    def send_to_client(server_storage, client_id: int) -> Mapping[Hashable, Any]:
        """Returns the context to send to the client corresponding to client_id.

        .. warning::
            Do not send shared objects like the server model (if you made any)
            before you deepcopy them.

        Args:
            server_storage (Storage): server storage object.
            client_id (int): id of the receiving client.

        Raises:
            NotImplementedError: abstract class to be implemented by child.

        Returns:
            Mapping[Hashable, Any]: the context to be sent, in the form of a
                Mapping.
        """
        raise NotImplementedError(
            "Algorithm is missing the required 'send_to_client' function"
        )
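
    # Example (illustrative sketch): a minimal ``send_to_client`` that ships a
    # deepcopy of the global model, assuming ``init`` stored it under the
    # hypothetical key "model".
    #
    #     from copy import deepcopy
    #
    #     def send_to_client(server_storage, client_id):
    #         return dict(model=deepcopy(server_storage.read("model")))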

    def send_to_server(
        id: int,
        rounds: int,
        storage: Dict[Hashable, Any],
        datasets: Dict[str, Iterable],
        train_split_name: str,
        scores: Dict[str, Dict[str, Any]],
        epochs: int,
        criterion: nn.Module,
        train_batch_size: int,
        inference_batch_size: int,
        optimizer_def: Callable,
        lr_scheduler_def: Optional[Callable] = None,
        device: Union[int, str] = "cuda",
        ctx: Optional[Dict[Hashable, Any]] = None,
        *args,
        **kwargs,
    ) -> Mapping[str, Any]:
        """Client operation on the received information.

        Args:
            id (int): id of the client
            rounds (int): global round number
            storage (Storage): storage object of the client
            datasets (Dict[str, Iterable]): this comes from the Data Manager
            train_split_name (str): name of the training split
            scores (Dict[str, Dict[str, Score]]): dictionary of the form
                ``{'split_name': {'score_name': Score}}`` for local scores to
                evaluate at the current round.
            epochs (int): number of epochs to train
            criterion (nn.Module): criterion; should be a differentiable
                ``fedsim.scores.Score``
            train_batch_size (int): training batch size
            inference_batch_size (int): inference batch size
            optimizer_def (Callable): class for constructing the local optimizer
            lr_scheduler_def (Callable): class for constructing the local lr
                scheduler
            device (Union[int, str], optional): Defaults to 'cuda'.
            ctx (Optional[Dict[Hashable, Any]], optional): context received from
                the server.

        Returns:
            Mapping[str, Any]: client context to be sent to the server.
        """
        raise NotImplementedError(
            "Algorithm is missing the required 'send_to_server' function"
        )

    def receive_from_client(
        server_storage: Storage,
        client_id: int,
        client_msg: Mapping[Hashable, Any],
        train_split_name: str,
        serial_aggregator: SerialAggregator,
        appendix_aggregator: AppendixAggregator,
    ) -> bool:
        """Receives and aggregates the info from the selected clients.

        Args:
            server_storage (Storage): server storage object.
            client_id (int): id of the sender (client).
            client_msg (Mapping[Hashable, Any]): client context that is sent.
            train_split_name (str): name of the training split on the clients.
            serial_aggregator (SerialAggregator): aggregator instance to collect
                serial info.
            appendix_aggregator (AppendixAggregator): aggregator instance to
                collect appendix info.

        Returns:
            bool: success of the aggregation.

        Raises:
            NotImplementedError: abstract class to be implemented by child.
        """
        raise NotImplementedError
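
    # Example (illustrative sketch): a FedAvg-style aggregation. The message
    # keys ("local_params", "num_samples") and the ``SerialAggregator.add``
    # signature used below are assumptions, not the library's confirmed API.
    #
    #     def receive_from_client(server_storage, client_id, client_msg,
    #                             train_split_name, serial_aggregator,
    #                             appendix_aggregator):
    #         weight = client_msg["num_samples"]
    #         serial_aggregator.add("local_params", client_msg["local_params"], weight)
    #         return True  # aggregation succeeded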

    def optimize(
        server_storage: Storage,
        serial_aggregator: SerialAggregator,
        appendix_aggregator: AppendixAggregator,
    ) -> Mapping[Hashable, Any]:
        """Optimizes the server model(s) and returns the scores to be reported.

        Args:
            server_storage (Storage): server storage object.
            serial_aggregator (SerialAggregator): serial aggregator instance of
                the current round.
            appendix_aggregator (AppendixAggregator): appendix aggregator
                instance of the current round.

        Raises:
            NotImplementedError: abstract class to be implemented by child.

        Returns:
            Mapping[Hashable, Any]: context to be reported.
        """
        raise NotImplementedError
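
    # Example (illustrative sketch, paired with the receive sketch above): pop
    # the weighted average of the client parameters and load it into the stored
    # global model. The ``SerialAggregator.pop`` semantics are an assumption.
    #
    #     def optimize(server_storage, serial_aggregator, appendix_aggregator):
    #         avg_params = serial_aggregator.pop("local_params")
    #         ...  # load avg_params into the model stored under "model"
    #         return dict()  # nothing extra to report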

    def deploy(server_storage: Storage) -> Optional[Mapping[Hashable, Any]]:
        """Returns a Mapping of name -> parameter set, used to test the model.

        Args:
            server_storage (Storage): server storage object.
        """
        raise NotImplementedError

    def report(
        server_storage: Storage,
        dataloaders: Dict[str, Any],
        rounds: int,
        round_scores: Dict[str, Dict[str, Any]],
        metric_logger: Optional[Any],
        device: str,
        optimize_reports: Mapping[Hashable, Any],
        deployment_points: Optional[Mapping[Hashable, torch.Tensor]] = None,
    ) -> Dict[str, Union[int, float]]:
        """Tests on the global data and reports the info. If a flat dict of
        str:Union[int, float] is returned from this function, its content is
        automatically logged using the metric logger (e.g.,
        ``logall.TensorboardLogger``). ``metric_logger`` is also passed as an
        input argument for extra (non-scalar) logging operations.

        Args:
            server_storage (Storage): server storage object.
            dataloaders (Dict[str, Any]): dict of data loaders to test the
                global model(s) on.
            rounds (int): current round number.
            round_scores (Dict[str, Dict[str, fedsim.scores.Score]]): dictionary
                of the form ``{'split_name': {'score_name': score_def}}`` for
                global scores to evaluate at the current round.
            metric_logger (Any, optional): the logging object (e.g.,
                ``logall.TensorboardLogger``).
            device (str): 'cuda', 'cpu' or gpu number.
            optimize_reports (Mapping[Hashable, Any]): dict returned by the
                optimizer.
            deployment_points (Mapping[Hashable, torch.Tensor], optional):
                output of the deploy method.

        Raises:
            NotImplementedError: abstract class to be implemented by child.
        """
        raise NotImplementedError
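
# Example (illustrative sketch): a minimal ``deploy``/``report`` pair for a
# subclass. The storage key "model" and the flat metric-key convention shown in
# the comment are assumptions, not the library's confirmed interface.
#
#     def deploy(server_storage):
#         return dict(avg=server_storage.read("model").state_dict())
#
#     def report(server_storage, dataloaders, rounds, round_scores,
#                metric_logger, device, optimize_reports,
#                deployment_points=None):
#         # evaluate each deployment point on each global split and return a
#         # flat dict such as {"server.avg.test.accuracy": 0.92} so that the
#         # base class logs it automatically
#         ...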