Source code for fedsim.distributed.centralized.training.adabest

r"""
AdaBest
-------
"""
from functools import partial

import torch

from fedsim.local.training.step_closures import default_step_closure
from fedsim.utils import SerialAggregator
from fedsim.utils import vector_to_parameters_like
from fedsim.utils import vectorize_module

from .fedavg import FedAvg
from .utils import serial_aggregation


class AdaBest(FedAvg):
    r"""Implements AdaBest algorithm for centralized FL. For further details
    regarding the algorithm we refer to `AdaBest: Minimizing Client Drift in
    Federated Learning via Adaptive Bias Estimation`_.

    Args:
        data_manager (``distributed.data_management.DataManager``): data manager
        metric_logger (``logall.Logger``): metric logger for tracking
        num_clients (int): number of clients
        sample_scheme (``str``): mode of sampling clients. Options are
            ``'uniform'`` and ``'sequential'``
        sample_rate (``float``): rate of sampling clients
        model_def (``torch.Module``): definition for constructing the model
        epochs (``int``): number of local epochs
        criterion_def (``Callable``): loss function defining the local objective
        optimizer_def (``Callable``): definition of the server optimizer
        local_optimizer_def (``Callable``): definition of the local optimizer
        lr_scheduler_def (``Callable``): definition of the lr scheduler of the
            server optimizer
        local_lr_scheduler_def (``Callable``): definition of the lr scheduler of
            the local optimizer
        r2r_local_lr_scheduler_def (``Callable``): definition of the scheduler
            for the lr delivered to the clients at each round (determines the
            initial lr of the client optimizer)
        batch_size (int): batch size of the local training
        test_batch_size (int): inference time batch size
        device (str): cpu, cuda, or gpu number
        mu (float): AdaBest's :math:`\mu` hyper-parameter for local
            regularization
        beta (float): AdaBest's :math:`\beta` hyper-parameter for global
            regularization

    .. note::
        definition of

        * learning rate schedulers could be any of the ones defined at
          ``torch.optim.lr_scheduler`` or any other that implements the
          ``step`` and ``get_last_lr`` methods.
        * optimizers could be any ``torch.optim.Optimizer``.
        * model could be any ``torch.Module``.
        * criterion could be any ``fedsim.scores.Score``.

    .. _AdaBest\: Minimizing Client Drift in Federated Learning via Adaptive
        Bias Estimation: https://arxiv.org/abs/2204.13170
    """
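
    # Server-side state written to ``server_storage`` by ``init``:
    #   cloud_params  -- parameters updated by the server optimizer
    #   avg_params    -- average of client parameters from the previous round
    #   h             -- global bias estimate, initialized to zero
    #   mu, beta      -- AdaBest hyper-parameters
    #   running_stats -- running average of samples per local step ("avg_m"),
    #                    delivered to clients as ``average_sample``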
    def init(server_storage, *args, **kwargs):
        default_mu = 0.1
        default_beta = 0.85
        FedAvg.init(server_storage)
        cloud_params = server_storage.read("cloud_params")
        server_storage.write("avg_params", cloud_params.clone().detach())
        server_storage.write("h", torch.zeros_like(cloud_params))
        server_storage.write("average_sample", 0)
        server_storage.write("mu", kwargs.get("mu", default_mu))
        server_storage.write("beta", kwargs.get("beta", default_beta))
        running_stats = SerialAggregator()
        running_stats.add("avg_m", 0, 0)
        server_storage.write("running_stats", running_stats)
    def send_to_client(server_storage, client_id):
        msg = FedAvg.send_to_client(server_storage, client_id)
        msg["average_sample"] = server_storage.read("running_stats").get("avg_m")
        msg["mu"] = server_storage.read("mu")
        return msg
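
    # Client-side local training. Each gradient step is corrected by the
    # client's bias estimate ``h``: ``-mu_adaptive * h`` is added to the
    # gradients, where ``mu_adaptive`` rescales ``mu`` by
    # ``average_sample * epochs / len(local trainset)``. After local training,
    # ``h`` is refreshed from the local pseudo-gradient, with the previous
    # estimate decayed by the number of rounds since this client's last
    # participation.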
    def send_to_server(
        id,
        rounds,
        storage,
        datasets,
        train_split_name,
        scores,
        epochs,
        criterion,
        train_batch_size,
        inference_batch_size,
        optimizer_def,
        lr_scheduler_def=None,
        device="cuda",
        ctx=None,
        step_closure=None,
    ):
        model = ctx["model"]
        mu = ctx["mu"]
        average_sample = ctx["average_sample"]
        params_init = vectorize_module(model, clone=True, detach=True)
        h = storage.read("h")
        mu_adaptive = mu / len(datasets[train_split_name]) * average_sample * epochs

        def transform_grads_fn(model):
            if h is not None:
                grad_additive = -h
                grad_additive_list = vector_to_parameters_like(
                    mu_adaptive * grad_additive, model.parameters()
                )
                for p, g_a in zip(model.parameters(), grad_additive_list):
                    p.grad += g_a

        step_closure_ = partial(
            default_step_closure, transform_grads=transform_grads_fn
        )
        opt_res = FedAvg.send_to_server(
            id,
            rounds,
            storage,
            datasets,
            train_split_name,
            scores,
            epochs,
            criterion,
            train_batch_size,
            inference_batch_size,
            optimizer_def,
            lr_scheduler_def,
            device,
            ctx,
            step_closure=step_closure_,
        )

        # update local h
        pseudo_grads = params_init - vectorize_module(model, clone=True, detach=True)
        last_round = storage.read("last_round")
        if h is None:
            new_h = pseudo_grads
        else:
            if last_round is None:
                last_round = rounds - 1
            new_h = 1 / (rounds - last_round) * h + pseudo_grads
        storage.write("h", new_h)
        storage.write("last_round", rounds)
        return opt_res
    def receive_from_client(
        server_storage,
        client_id,
        client_msg,
        train_split_name,
        serial_aggregator,
        appendix_aggregator,
    ):
        weight = 1
        agg_res = serial_aggregation(
            server_storage,
            client_id,
            client_msg,
            train_split_name,
            serial_aggregator,
            train_weight=weight,
        )
        n_train = client_msg["num_samples"][train_split_name]
        running_stats = server_storage.read("running_stats")
        running_stats.add("avg_m", n_train / client_msg["num_steps"], 1)
        return agg_res
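
    # Server-side optimization. The round's bias estimate is
    # ``h = beta * (avg_params_prev - param_avg)`` and the target parameters
    # are ``param_avg - h``; the difference ``cloud_params - (param_avg - h)``
    # is passed to the server optimizer as a pseudo-gradient. With plain SGD
    # at lr=1 this reduces to ``cloud_params <- param_avg - h``.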
    def optimize(server_storage, serial_aggregator, appendix_aggregator):
        if "local_params" in serial_aggregator:
            beta = server_storage.read("beta")
            param_avg = serial_aggregator.pop("local_params")
            optimizer = server_storage.read("optimizer")
            lr_scheduler = server_storage.read("lr_scheduler")
            cloud_params = server_storage.read("cloud_params")
            h = beta * (server_storage.read("avg_params") - param_avg)
            new_params = param_avg - h
            modified_pseudo_grads = cloud_params.data - new_params
            # update cloud params
            optimizer.zero_grad()
            cloud_params.grad = modified_pseudo_grads
            optimizer.step()
            if lr_scheduler is not None:
                lr_scheduler.step()
            server_storage.write("avg_params", param_avg.detach().clone())
        return serial_aggregator.pop_all()
    def deploy(server_storage):
        return dict(
            cloud=server_storage.read("cloud_params"),
            avg=server_storage.read("avg_params"),
        )
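

# ----------------------------------------------------------------------------
# Minimal usage sketch. It only shows how the constructor arguments documented
# in the ``AdaBest`` docstring fit together; the helper paths imported below
# (``BasicDataManager``, ``CrossEntropyScore``, ``TensorboardLogger``), the
# dataset name, and the ``train(rounds=...)`` entry point are assumptions
# about the surrounding fedsim/logall API and may need adjusting for the
# installed version.
# ----------------------------------------------------------------------------
if __name__ == "__main__":
    import torch.nn as nn

    from logall import TensorboardLogger  # assumed logger backend

    from fedsim.distributed.data_management import BasicDataManager  # assumed
    from fedsim.scores import CrossEntropyScore  # assumed criterion score

    n_clients = 100
    dm = BasicDataManager("./data", "mnist", n_clients)  # assumed dataset name

    def model_def():
        # tiny stand-in model; any ``torch.nn.Module`` definition works here
        return nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))

    alg = AdaBest(
        data_manager=dm,
        metric_logger=TensorboardLogger(path=None),
        num_clients=n_clients,
        sample_scheme="uniform",
        sample_rate=0.1,
        model_def=model_def,
        epochs=5,
        criterion_def=partial(CrossEntropyScore, log_freq=100),
        batch_size=32,
        test_batch_size=64,
        device="cuda",
        mu=0.1,     # AdaBest local regularization coefficient
        beta=0.85,  # AdaBest global regularization coefficient
    )
    report = alg.train(rounds=100)  # assumed training entry point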