Source code for fedsim.distributed.centralized.training.feddyn

r"""
FedDyn
-------
"""
from functools import partial

import torch

from fedsim.local.training.step_closures import default_step_closure
from fedsim.utils import vector_to_parameters_like
from fedsim.utils import vectorize_module

from .fedavg import FedAvg
from .utils import serial_aggregation


[docs]class FedDyn(FedAvg):
    r"""Implements FedDyn algorithm for centralized FL.

    For further details regarding the algorithm we refer to `Federated Learning
    Based on Dynamic Regularization`_.

    Args:
        data_manager (``distributed.data_management.DataManager``): data manager
        metric_logger (``logall.Logger``): metric logger for tracking.
        num_clients (int): number of clients
        sample_scheme (``str``): mode of sampling clients. Options are ``'uniform'``
            and ``'sequential'``
        sample_rate (``float``): rate of sampling clients
        model_def (``torch.Module``): definition for constructing the model
        epochs (``int``): number of local epochs
        criterion_def (``Callable``): loss function defining local objective
        optimizer_def (``Callable``): definition of server optimizer
        local_optimizer_def (``Callable``): definition of local optimizer
        lr_scheduler_def (``Callable``): definition of lr scheduler of server
            optimizer.
        local_lr_scheduler_def (``Callable``): definition of lr scheduler of local
            optimizer
        r2r_local_lr_scheduler_def (``Callable``): definition to schedule lr that is
            delivered to the clients at each round (determines the initial lr of
            the client optimizer)
        batch_size (int): batch size of the local training
        test_batch_size (int): inference time batch size
        device (str): cpu, cuda, or gpu number
        alpha (float): FedDyn's :math:`\alpha` hyper-parameter for local
            regularization

    .. note::
        definition of

        * learning rate schedulers, could be any of the ones defined at
          ``torch.optim.lr_scheduler`` or any other that implements step and
          get_last_lr methods.
        * optimizers, could be any ``torch.optim.Optimizer``.
        * model, could be any ``torch.Module``.
        * criterion, could be any ``fedsim.scores.Score``.

    .. _Federated Learning Based on Dynamic Regularization:
        https://openreview.net/forum?id=B7v4QMR6Z9w
    """
[docs]    def init(server_storage, *args, **kwargs):
        default_alpha = 0.1
        local_train_split_name = "train"
        FedAvg.init(server_storage)
        cloud_params = server_storage.read("cloud_params")
        server_storage.write("avg_params", cloud_params.clone().detach())
        server_storage.write("h", torch.zeros_like(cloud_params))
        server_storage.write("average_sample", 0)
        server_storage.write("alpha", kwargs.get("alpha", default_alpha))
        # oracle read violation, num_clients read violation
        print("Warning: private access violation")
        print("\t", end="")
        print(
            "FedDyn assumes prior knowledge on the number of clients and "
            "oracle samples"
        )
        oracle_dataset = server_storage.read("oracle_dataset")[local_train_split_name]
        num_clients = server_storage.read("num_clients")
        average_sample = len(oracle_dataset) / num_clients
        server_storage.write("average_sample", average_sample)
[docs]    def send_to_client(server_storage, client_id):
        msg = FedAvg.send_to_client(server_storage, client_id)
        msg["average_sample"] = server_storage.read("average_sample")
        msg["alpha"] = server_storage.read("alpha")
        return msg
[docs]    def send_to_server(
        id,
        rounds,
        storage,
        datasets,
        train_split_name,
        metrics,
        epochs,
        criterion,
        train_batch_size,
        inference_batch_size,
        optimizer_def,
        lr_scheduler_def=None,
        device="cuda",
        ctx=None,
        step_closure=None,
    ):
        model = ctx["model"]
        alpha = ctx["alpha"]
        average_sample = ctx["average_sample"]
        params_init = vectorize_module(model, clone=True, detach=True)
        h = storage.read("h")
        alpha_adaptive = alpha / len(datasets[train_split_name]) * average_sample

        def transform_grads_fn(model):
            params = vectorize_module(model)
            grad_additive = 0.5 * (params - params_init)
            if h is not None:
                grad_additive -= h

            grad_additive_list = vector_to_parameters_like(
                alpha_adaptive * grad_additive, model.parameters()
            )
            for p, g_a in zip(model.parameters(), grad_additive_list):
                p.grad += g_a

        step_closure_ = partial(
            default_step_closure, transform_grads=transform_grads_fn
        )
        opt_res = FedAvg.send_to_server(
            id,
            rounds,
            storage,
            datasets,
            train_split_name,
            metrics,
            epochs,
            criterion,
            train_batch_size,
            inference_batch_size,
            optimizer_def,
            lr_scheduler_def,
            device,
            ctx,
            step_closure=step_closure_,
        )

        # update local h
        pseudo_grads = params_init - vectorize_module(model, clone=True, detach=True)
        new_h = pseudo_grads if h is None else pseudo_grads + h
        storage.write("h", new_h)
        return opt_res
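    # Note on the client-side update above: in the FedDyn paper linked in the
    # class docstring, each sampled client minimizes its empirical loss plus a
    # dynamic regularizer made of a proximal term toward the received cloud
    # parameters and a linear correction term accumulated across rounds. Here,
    # `transform_grads_fn` adds the gradient of that regularizer on top of the
    # gradients produced by the local step closure, the client-side `h` plays
    # the role of the accumulated correction, and `alpha_adaptive` rescales
    # alpha by the ratio of the average sample count to the client's dataset
    # size (a normalization specific to this library).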
[docs]    def receive_from_client(
        server_storage,
        client_id,
        client_msg,
        train_split_name,
        serial_aggregator,
        appendix_aggregator,
    ):
        weight = 1
        return serial_aggregation(
            server_storage,
            client_id,
            client_msg,
            train_split_name,
            serial_aggregator,
            train_weight=weight,
        )
[docs]    def optimize(server_storage, serial_aggregator, appendix_aggregator):
        if "local_params" in serial_aggregator:
            weight = serial_aggregator.get_weight("local_params")
            param_avg = serial_aggregator.pop("local_params")
            optimizer = server_storage.read("optimizer")
            lr_scheduler = server_storage.read("lr_scheduler")
            cloud_params = server_storage.read("cloud_params")
            pseudo_grads = cloud_params.data - param_avg
            h = server_storage.read("h")
            num_clients = server_storage.read("num_clients")
            # already warned, disable read protection from num_clients
            server_storage.change_protection("num_clients", False, False)
            h = h + weight / num_clients * pseudo_grads
            new_params = param_avg - h
            modified_pseudo_grads = cloud_params.data - new_params
            # update cloud params
            optimizer.zero_grad()
            cloud_params.grad = modified_pseudo_grads
            optimizer.step()
            if lr_scheduler is not None:
                lr_scheduler.step()
            server_storage.write("avg_params", param_avg.detach().clone())
            server_storage.write("h", h.data)
            # purge aggregated results
            del param_avg
        return serial_aggregator.pop_all()
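    # Note on the server-side update above: the server state `h` accumulates the
    # pseudo-gradient of the aggregation step (cloud parameters minus the averaged
    # local parameters), weighted by the fraction of participating clients, and
    # the target handed to the server optimizer is the plain average shifted by
    # this state (new_params = param_avg - h). This mirrors the server-side state
    # update of FedDyn; the exact scaling is an implementation choice of this
    # library.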
[docs]    def deploy(server_storage):
        return dict(
            cloud=server_storage.read("cloud_params"),
            avg=server_storage.read("avg_params"),
        )
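

# ---------------------------------------------------------------------------
# Usage sketch: a minimal example of how FedDyn could be driven end to end,
# following the constructor arguments documented in the class docstring. The
# data manager, logger, score class, and the `train` entry point below are
# assumptions about the surrounding fedsim/logall API (exact names and
# signatures may differ across versions), and the model is a throwaway
# placeholder.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from logall import TensorboardLogger  # assumed logger implementation
    from fedsim.distributed.data_management import BasicDataManager  # assumed
    from fedsim.scores import CrossEntropyScore  # assumed score/criterion class

    class TinyModel(torch.nn.Module):
        """Placeholder model definition; any ``torch.Module`` works."""

        def __init__(self):
            super().__init__()
            self.net = torch.nn.Sequential(
                torch.nn.Flatten(), torch.nn.Linear(28 * 28, 10)
            )

        def forward(self, x):
            return self.net(x)

    algorithm = FedDyn(
        data_manager=BasicDataManager("./data", "mnist", 100),  # assumed signature
        metric_logger=TensorboardLogger(path="logs"),  # assumed signature
        num_clients=100,
        sample_scheme="uniform",
        sample_rate=0.01,
        model_def=TinyModel,
        epochs=1,
        criterion_def=CrossEntropyScore,
        optimizer_def=partial(torch.optim.SGD, lr=1.0),
        local_optimizer_def=partial(torch.optim.SGD, lr=0.1),
        batch_size=32,
        test_batch_size=64,
        device="cpu",
        alpha=0.01,
    )
    # assumed training-loop entry point inherited from the base algorithm
    algorithm.train(rounds=100)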