Source code for fedsim.distributed.centralized.training.fednova
r"""
FedNova
-------
"""
from .fedavg import FedAvg
from .utils import serial_aggregation
class FedNova(FedAvg):
    r"""Implements FedNova algorithm for centralized FL.

    For further details regarding the algorithm we refer to `Tackling the Objective
    Inconsistency Problem in Heterogeneous Federated Optimization`_.

    Args:
        data_manager (``distributed.data_management.DataManager``): data manager
        metric_logger (``logall.Logger``): metric logger for tracking.
        num_clients (int): number of clients
        sample_scheme (``str``): mode of sampling clients. Options are ``'uniform'``
            and ``'sequential'``
        sample_rate (``float``): rate of sampling clients
        model_def (``torch.Module``): definition for constructing the model
        epochs (``int``): number of local epochs
        criterion_def (``Callable``): loss function defining local objective
        optimizer_def (``Callable``): definition of server optimizer
        local_optimizer_def (``Callable``): definition of local optimizer
        lr_scheduler_def (``Callable``): definition of lr scheduler of server
            optimizer.
        local_lr_scheduler_def (``Callable``): definition of lr scheduler of local
            optimizer
        r2r_local_lr_scheduler_def (``Callable``): definition to schedule lr that is
            delivered to the clients at each round (determines the initial lr of
            the client optimizer)
        batch_size (int): batch size of the local training
        test_batch_size (int): inference time batch size
        device (str): cpu, cuda, or gpu number

    .. note::
        definition of

        * learning rate schedulers, could be any of the ones defined at
          ``torch.optim.lr_scheduler`` or any other that implements step and
          get_last_lr methods.
        * optimizers, could be any ``torch.optim.Optimizer``.
        * model, could be any ``torch.Module``.
        * criterion, could be any ``fedsim.scores.Score``.

    .. _Tackling the Objective Inconsistency Problem in Heterogeneous Federated
        Optimization: https://arxiv.org/abs/2007.07481
    """

    def receive_from_client(
        server_storage,
        client_id,
        client_msg,
        train_split_name,
        serial_aggregator,
        appendix_aggregator,
    ):
        """Aggregate one client's message with a step-normalized weight.

        The aggregation weight is the client's number of training samples
        (``client_msg["num_samples"][train_split_name]``) divided by the number
        of local steps it reported (``client_msg["num_steps"]``). The message
        is then handed to ``serial_aggregation`` with that value as
        ``train_weight``.

        Note that this function is declared without ``self``; the signature is
        kept exactly as the framework expects it.

        Args:
            server_storage: server-side storage, forwarded as-is.
            client_id: identifier of the reporting client, forwarded as-is.
            client_msg: mapping holding at least ``"num_samples"`` (indexable
                by ``train_split_name``) and ``"num_steps"``.
            train_split_name: name of the training split used to look up the
                client's sample count.
            serial_aggregator: aggregator forwarded to ``serial_aggregation``.
            appendix_aggregator: accepted for interface compatibility; not
                used by this implementation.

        Returns:
            Whatever ``serial_aggregation`` returns.
        """
        # Lookup order (samples first, then steps) mirrors the original so a
        # malformed message fails on the same key.
        sample_count = client_msg["num_samples"][train_split_name]
        local_steps = client_msg["num_steps"]
        normalized_weight = sample_count / local_steps

        return serial_aggregation(
            server_storage,
            client_id,
            client_msg,
            train_split_name,
            serial_aggregator,
            train_weight=normalized_weight,
        )