FedNova#

class FedNova(data_manager, metric_logger, num_clients, sample_scheme, sample_rate, model_def, epochs, criterion_def, optimizer_def=functools.partial(<class 'torch.optim.sgd.SGD'>, lr=1.0), local_optimizer_def=functools.partial(<class 'torch.optim.sgd.SGD'>, lr=0.1), lr_scheduler_def=None, local_lr_scheduler_def=None, r2r_local_lr_scheduler_def=None, batch_size=32, test_batch_size=64, device='cpu', *args, **kwargs)[source]#

Implements the FedNova algorithm for centralized FL.

For further details regarding the algorithm, we refer to Tackling the Objective Inconsistency Problem in Heterogeneous Federated Optimization.
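As a rough illustration of the idea (not fedsim's internal code), FedNova normalizes each client's accumulated update by the number of local steps it performed before averaging, so that clients running more local steps do not bias the global update. A minimal sketch, assuming vanilla local SGD without momentum; all names below are hypothetical:

    # Sketch of FedNova-style normalized averaging (illustration only;
    # names are hypothetical and this is not fedsim's implementation).
    import torch

    def fednova_aggregate(global_params, client_deltas, client_steps, client_weights):
        """Aggregate client updates with FedNova normalization.

        client_deltas[i]  : local_params_i - global_params (a flat tensor)
        client_steps[i]   : tau_i, number of local SGD steps taken by client i
        client_weights[i] : p_i, relative data share of client i (sums to 1)
        """
        # Normalize each client's update by its own number of local steps.
        normalized = [d / t for d, t in zip(client_deltas, client_steps)]
        # Weighted average of the normalized updates.
        avg_update = sum(w * d for w, d in zip(client_weights, normalized))
        # Effective number of steps restores the overall update magnitude.
        tau_eff = sum(w * t for w, t in zip(client_weights, client_steps))
        return global_params + tau_eff * avg_update

    # Example: two clients, one took 5 local steps, the other 20.
    x = torch.zeros(3)
    new_x = fednova_aggregate(
        x,
        client_deltas=[torch.ones(3), 4.0 * torch.ones(3)],
        client_steps=[5, 20],
        client_weights=[0.5, 0.5],
    )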

Parameters
  • data_manager (distributed.data_management.DataManager) -- data manager

  • metric_logger (logall.Logger) -- metric logger for tracking.

  • num_clients (int) -- number of clients

  • sample_scheme (str) -- mode of sampling clients. Options are 'uniform' and 'sequential'

  • sample_rate (float) -- rate of sampling clients

  • model_def (torch.nn.Module) -- definition used for constructing the model

  • epochs (int) -- number of local epochs

  • criterion_def (Callable) -- loss function defining local objective

  • optimizer_def (Callable) -- definition of the server optimizer

  • local_optimizer_def (Callable) -- definition of the local optimizer

  • lr_scheduler_def (Callable) -- definition of lr scheduler of server optimizer.

  • local_lr_scheduler_def (Callable) -- definition of lr scheduler of local optimizer

  • r2r_local_lr_scheduler_def (Callable) -- definition of a round-to-round scheduler for the lr delivered to the clients at each round (determines the initial lr of the client optimizer)

  • batch_size (int) -- batch size of the local training

  • test_batch_size (int) -- inference time batch size

  • device (str) -- cpu, cuda, or gpu number

Note

definition of
  • learning rate schedulers could be any of the ones defined in torch.optim.lr_scheduler, or any other that implements the step and get_last_lr methods.

  • optimizers could be any torch.optim.Optimizer.

  • model could be any torch.nn.Module.

  • criterion could be any fedsim.scores.Score.
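For example, these definitions are typically supplied as partials. A hypothetical construction sketch follows; the data manager, logger, model, and criterion placeholders are assumed to be defined elsewhere and are not part of fedsim's API:

    # Hypothetical construction sketch; data_manager, metric_logger,
    # model_def and criterion_def are placeholders defined elsewhere.
    from functools import partial
    from torch.optim import SGD

    algorithm = FedNova(
        data_manager=data_manager,            # a DataManager instance
        metric_logger=metric_logger,          # a logall.Logger instance
        num_clients=100,
        sample_scheme='uniform',
        sample_rate=0.1,
        model_def=model_def,                  # e.g. a partial of a torch module
        epochs=2,
        criterion_def=criterion_def,          # e.g. a fedsim.scores.Score partial
        optimizer_def=partial(SGD, lr=1.0),   # server optimizer
        local_optimizer_def=partial(SGD, lr=0.1),
        batch_size=32,
        test_batch_size=64,
        device='cpu',
    )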

receive_from_client(server_storage, client_id, client_msg, train_split_name, serial_aggregator, appendix_aggregator)[source]#

receive and aggregate info from selected clients

Parameters
  • server_storage (Storage) -- server storage object.

  • client_id (int) -- id of the sender (client)

  • client_msg (Mapping[Hashable, Any]) -- client context that is sent.

  • train_split_name (str) -- name of the training split on clients.

  • serial_aggregator (SerialAggregator) -- serial aggregator instance to collect info.

  • appendix_aggregator (AppendixAggregator) -- appendix aggregator instance to collect info.

Returns

bool -- success of the aggregation.

Raises

NotImplementedError -- abstract method that must be implemented by the child class
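As a loose illustration of how such a hook aggregates client messages (not fedsim's actual implementation; the message keys and the aggregator add/append methods shown here are assumptions for illustration only):

    # Hypothetical sketch; message keys and aggregator methods are assumed.
    def receive_from_client(server_storage, client_id, client_msg,
                            train_split_name, serial_aggregator,
                            appendix_aggregator):
        weight = client_msg['num_samples']   # assumed key in the client message
        # Accumulate a sample-weighted running average of the reported local loss.
        serial_aggregator.add('local_loss', client_msg['local_loss'], weight)
        # Keep the raw per-client value as well for later inspection.
        appendix_aggregator.append('client_losses', client_msg['local_loss'])
        return True                          # report successful aggregation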