Source code for fedsim.distributed.centralized.training.fedprox

r"""
FedProx
-------
"""
from functools import partial

from torch.nn.utils import parameters_to_vector

from fedsim.local.training.step_closures import default_step_closure
from fedsim.utils import vector_to_parameters_like
from fedsim.utils import vectorize_module

from .fedavg import FedAvg


class FedProx(FedAvg):
    r"""Implements FedProx algorithm for centralized FL. For further details
    regarding the algorithm we refer to `Federated Optimization in
    Heterogeneous Networks`_.

    Args:
        data_manager (``distributed.data_management.DataManager``): data manager
        metric_logger (``logall.Logger``): metric logger for tracking.
        num_clients (int): number of clients
        sample_scheme (``str``): mode of sampling clients. Options are
            ``'uniform'`` and ``'sequential'``
        sample_rate (``float``): rate of sampling clients
        model_def (``torch.nn.Module``): definition used to construct the model
        epochs (``int``): number of local epochs
        criterion_def (``Callable``): loss function defining the local objective
        optimizer_def (``Callable``): definition of the server optimizer
        local_optimizer_def (``Callable``): definition of the local optimizer
        lr_scheduler_def (``Callable``): definition of the lr scheduler of the
            server optimizer.
        local_lr_scheduler_def (``Callable``): definition of the lr scheduler of
            the local optimizer
        r2r_local_lr_scheduler_def (``Callable``): definition to schedule the lr
            delivered to the clients at each round (determines the initial lr of
            the client optimizer)
        batch_size (int): batch size of the local training
        test_batch_size (int): inference time batch size
        device (str): cpu, cuda, or gpu number
        mu (float): FedProx's :math:`\mu` hyper-parameter for local
            regularization

    .. note::
        definition of

        * learning rate schedulers, could be any of the ones defined at
          ``torch.optim.lr_scheduler`` or any other that implements ``step``
          and ``get_last_lr`` methods.
        * optimizers, could be any ``torch.optim.Optimizer``.
        * model, could be any ``torch.nn.Module``.
        * criterion, could be any ``fedsim.scores.Score``.

    .. _Federated Optimization in Heterogeneous Networks:
        https://arxiv.org/abs/1812.06127
    """

    def init(server_storage, *args, **kwargs):
        # register FedProx's mu hyper-parameter in the server storage,
        # falling back to a small default when it is not provided
        default_mu = 0.0001
        FedAvg.init(server_storage)
        server_storage.write("mu", kwargs.get("mu", default_mu))

    def send_to_client(server_storage, client_id):
        # attach mu to the payload FedAvg already prepares for the client
        server_msg = FedAvg.send_to_client(server_storage, client_id)
        server_msg["mu"] = server_storage.read("mu")
        return server_msg

    def send_to_server(
        id,
        rounds,
        storage,
        datasets,
        train_split_name,
        scores,
        epochs,
        criterion,
        train_batch_size,
        inference_batch_size,
        optimizer_def,
        lr_scheduler_def=None,
        device="cuda",
        ctx=None,
        step_closure=None,
    ):
        model = ctx["model"]
        mu = ctx["mu"]
        # snapshot of the global parameters received from the server; the
        # proximal term penalizes drift away from this point
        params_init = vectorize_module(model, clone=True, detach=True)

        def transform_grads_fn(model):
            # add the gradient of the proximal regularizer to each
            # parameter's gradient before the optimizer step
            params = parameters_to_vector(model.parameters())
            grad_additive = 0.5 * (params - params_init)
            grad_additive_list = vector_to_parameters_like(
                mu * grad_additive, model.parameters()
            )

            for p, g_a in zip(model.parameters(), grad_additive_list):
                p.grad += g_a

        step_closure_ = partial(
            default_step_closure, transform_grads=transform_grads_fn
        )
        return FedAvg.send_to_server(
            id,
            rounds,
            storage,
            datasets,
            train_split_name,
            scores,
            epochs,
            criterion,
            train_batch_size,
            inference_batch_size,
            optimizer_def,
            lr_scheduler_def,
            device,
            ctx,
            step_closure=step_closure_,
        )
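
The heart of ``send_to_server`` is ``transform_grads_fn``: FedProx augments the
local objective with a proximal term, F_k(w) + (mu / 2) * ||w - w^t||^2, whose
gradient contribution is mu * (w - w^t). The snippet below is a self-contained
PyTorch illustration of that gradient correction on a toy model, independent of
fedsim (note the code above applies 0.5 * mu * (w - w^t), which amounts to the
same regularizer with an effectively halved mu):

    import torch
    from torch.nn.utils import parameters_to_vector

    model = torch.nn.Linear(4, 2)
    mu = 0.01

    # w^t: the global parameters received from the server
    params_init = parameters_to_vector(model.parameters()).clone().detach()

    x, y = torch.randn(8, 4), torch.randint(0, 2, (8,))
    loss = torch.nn.functional.cross_entropy(model(x), y)
    loss.backward()

    # add mu * (w - w^t), the gradient of (mu / 2) * ||w - w^t||^2
    with torch.no_grad():
        additive = mu * (parameters_to_vector(model.parameters()) - params_init)
        offset = 0
        for p in model.parameters():
            n = p.numel()
            p.grad += additive[offset : offset + n].view_as(p)
            offset += n

    torch.optim.SGD(model.parameters(), lr=0.1).step()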
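
The ``init``/``send_to_client`` pair illustrates the hook pattern used here:
values written to ``server_storage`` at initialization are read back and
attached to the per-round message. A hedged sketch of extending FedProx in the
same style; ``MyFedProx`` and ``my_temperature`` are hypothetical names, not
part of the library:

    class MyFedProx(FedProx):
        def init(server_storage, *args, **kwargs):
            # keep FedProx's setup (which writes "mu"), then register an
            # extra, hypothetical scalar
            FedProx.init(server_storage, *args, **kwargs)
            server_storage.write("my_temperature", kwargs.get("my_temperature", 1.0))

        def send_to_client(server_storage, client_id):
            # ship the extra value alongside everything FedProx already sends
            server_msg = FedProx.send_to_client(server_storage, client_id)
            server_msg["my_temperature"] = server_storage.read("my_temperature")
            return server_msg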
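
A minimal construction sketch, based only on the constructor arguments
documented in the class docstring above. ``MyDataManager`` and ``MyModel`` are
hypothetical placeholders for a concrete ``DataManager`` subclass and a
``torch.nn.Module`` definition; the ``partial``-wrapped optimizers follow the
``optimizer_def``/``local_optimizer_def`` convention described there:

    from functools import partial

    from torch.optim import SGD

    from fedsim.distributed.centralized.training.fedprox import FedProx

    algorithm = FedProx(
        data_manager=MyDataManager(),  # hypothetical DataManager subclass
        num_clients=100,
        sample_scheme="uniform",
        sample_rate=0.1,
        model_def=MyModel,             # hypothetical torch.nn.Module definition
        epochs=1,
        criterion_def=...,             # any fedsim.scores.Score definition
        optimizer_def=partial(SGD, lr=1.0),
        local_optimizer_def=partial(SGD, lr=0.1),
        batch_size=32,
        device="cuda",
        mu=0.01,                       # FedProx proximal coefficient
    )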