FedNova
- class FedNova(data_manager, metric_logger, num_clients, sample_scheme, sample_rate, model_def, epochs, criterion_def, optimizer_def=partial(torch.optim.SGD, lr=1.0), local_optimizer_def=partial(torch.optim.SGD, lr=0.1), lr_scheduler_def=None, local_lr_scheduler_def=None, r2r_local_lr_scheduler_def=None, batch_size=32, test_batch_size=64, device='cpu', *args, **kwargs)[source]
Implements FedNova algorithm for centralized FL.
For further details regarding the algorithm, see Tackling the Objective Inconsistency Problem in Heterogeneous Federated Optimization.
- Parameters
data_manager (distributed.data_management.DataManager) -- data manager
metric_logger (logall.Logger) -- metric logger for tracking
num_clients (int) -- number of clients
sample_scheme (str) -- client sampling mode; options are 'uniform' and 'sequential'
sample_rate (float) -- rate of sampling clients
model_def (torch.Module) -- definition for constructing the model
epochs (int) -- number of local epochs
criterion_def (Callable) -- loss function defining the local objective
optimizer_def (Callable) -- definition of the server optimizer
local_optimizer_def (Callable) -- definition of the local optimizer
lr_scheduler_def (Callable) -- definition of the lr scheduler of the server optimizer
local_lr_scheduler_def (Callable) -- definition of the lr scheduler of the local optimizer
r2r_local_lr_scheduler_def (Callable) -- definition of a scheduler for the lr delivered to the clients at each round (determines the initial lr of the client optimizer)
batch_size (int) -- batch size of local training
test_batch_size (int) -- inference-time batch size
device (str) -- 'cpu', 'cuda', or a gpu number
Note
- Learning rate scheduler definitions can be any of those defined in torch.optim.lr_scheduler, or any other class that implements the step and get_last_lr methods.
- Optimizers can be any torch.optim.Optimizer.
- The model can be any torch.Module.
- The criterion can be any fedsim.scores.Score.
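To make the aggregation concrete, below is a minimal pure-Python sketch of FedNova's normalized averaging rule from the referenced paper: each client's update is normalized by its number of local steps tau_i, and the server applies the weighted average scaled by the effective step count tau_eff. The function name and flat-list parameter representation are illustrative only, not fedsim's actual API.

```python
def fednova_aggregate(global_params, client_params, client_weights, client_steps):
    """Sketch of FedNova's normalized averaging (illustrative, not fedsim's API).

    global_params: flat list of global model parameters x
    client_params: per-client flat lists of locally trained parameters y_i
    client_weights: relative data shares p_i (summing to 1)
    client_steps: number of local steps tau_i taken by each client
    """
    # Normalized update per client: d_i = (x - y_i) / tau_i
    normalized = [
        [(x - y) / tau for x, y in zip(global_params, params)]
        for params, tau in zip(client_params, client_steps)
    ]
    # Effective number of steps: tau_eff = sum_i p_i * tau_i
    tau_eff = sum(p * tau for p, tau in zip(client_weights, client_steps))
    # Weighted average of the normalized client updates
    avg = [
        sum(p * d[k] for p, d in zip(client_weights, normalized))
        for k in range(len(global_params))
    ]
    # Server step with lr 1.0 (the default server optimizer lr above)
    return [x - tau_eff * a for x, a in zip(global_params, avg)]
```

With two equally weighted clients that ran a different number of local steps, the heterogeneity in steps no longer biases the average toward the client that ran longer, which is the objective-inconsistency fix FedNova introduces.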
- receive_from_client(client_id, client_msg, train_split_name, serial_aggregator, appendix_aggregator)[source]
Receive and aggregate info from the selected clients.
- Parameters
server_storage (Storage) -- server storage object
client_id (int) -- id of the sender (client)
client_msg (Mapping[Hashable, Any]) -- client context that is sent
train_split_name (str) -- name of the training split on clients
serial_aggregator (SerialAggregator) -- aggregator instance to collect info
- Returns
bool -- success of the aggregation.
- Raises
NotImplementedError -- abstract class to be implemented by child
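To illustrate what collecting info serially might look like, here is a hypothetical weighted aggregator sketch: it accumulates (weighted sum, total weight) pairs per key as each client's message arrives, so no per-client buffers are kept. This is not fedsim's SerialAggregator implementation, just an assumed minimal behavior.

```python
class WeightedSerialAggregator:
    """Hypothetical incremental weighted aggregator (not fedsim's class)."""

    def __init__(self):
        # key -> (weighted sum, total weight)
        self._sums = {}

    def add(self, key, value, weight=1.0):
        # Fold one client's contribution in without storing it separately
        s, w = self._sums.get(key, (0.0, 0.0))
        self._sums[key] = (s + value * weight, w + weight)

    def get(self, key):
        # Weighted average of everything added so far
        s, w = self._sums[key]
        return s / w
```

A server-side receive hook can then call add once per client message and read the aggregate after the round, keeping memory constant in the number of sampled clients.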