# Source code for fedsim.distributed.centralized.training.feddf

r"""
AvgLogits
---------
"""

from torch.nn.functional import log_softmax
from torch.nn.utils.stateless import functional_call

from fedsim.scores import KLDivScore
from fedsim.utils import SerialAggregator
from fedsim.utils import initialize_module
from fedsim.utils import vector_to_named_parameters_like

from .fedavg import FedAvg
from .utils import serial_aggregation


class FedDF(FedAvg):
    r"""Ensemble Distillation for Robust Model Fusion in Federated Learning.

    For further details regarding the algorithm we refer to `Ensemble Distillation
    for Robust Model Fusion in Federated Learning`_.

    Args:
        data_manager (``distributed.data_management.DataManager``): data manager
        metric_logger (``logall.Logger``): metric logger for tracking.
        num_clients (int): number of clients
        sample_scheme (``str``): mode of sampling clients. Options are ``'uniform'``
            and ``'sequential'``
        sample_rate (``float``): rate of sampling clients
        model_def (``torch.Module``): definition used to construct the model
        epochs (``int``): number of local epochs
        criterion_def (``Callable``): loss function defining the local objective
        optimizer_def (``Callable``): definition of the server optimizer
        local_optimizer_def (``Callable``): definition of the local optimizer
        lr_scheduler_def (``Callable``): definition of the lr scheduler of the
            server optimizer.
        local_lr_scheduler_def (``Callable``): definition of the lr scheduler of
            the local optimizer
        r2r_local_lr_scheduler_def (``Callable``): definition of the scheduler for
            the lr delivered to the clients at each round (determines the initial
            lr of the client optimizer)
        batch_size (int): batch size of the local training
        test_batch_size (int): inference time batch size
        device (str): cpu, cuda, or gpu number
        global_train_split (str): the name of the train split to be used on the
            server
        global_epochs (int): number of training epochs on the server

    .. note::
        definition of

        * learning rate schedulers, could be any of the ones defined at
          ``torch.optim.lr_scheduler`` or any other that implements the step and
          get_last_lr methods.
        * optimizers, could be any ``torch.optim.Optimizer``.
        * model, could be any ``torch.Module``.
        * criterion, could be any ``fedsim.scores.Score``.

    .. warning::
        this algorithm needs a split for training on the server. This means that
        the global datasets provided in the data manager should include an extra
        split.

    .. _Ensemble Distillation for Robust Model Fusion in Federated Learning:
        https://openreview.net/forum?id=gjrMQoAhSRq
    """
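
    # A minimal construction sketch inferred from the argument list documented
    # above. Hedged: ``my_data_manager``, ``my_logger``, ``my_criterion_def`` and
    # ``MyModel`` are hypothetical placeholders built elsewhere, and the keyword
    # values shown are illustrative assumptions, not verified fedsim defaults.
    #
    #     from functools import partial
    #     from torch.optim import SGD
    #
    #     algorithm = FedDF(
    #         data_manager=my_data_manager,
    #         metric_logger=my_logger,
    #         num_clients=100,
    #         sample_scheme="uniform",
    #         sample_rate=0.1,
    #         model_def=MyModel,
    #         epochs=1,
    #         criterion_def=my_criterion_def,  # any ``fedsim.scores.Score``, per the note above
    #         optimizer_def=partial(SGD, lr=1.0),
    #         local_optimizer_def=partial(SGD, lr=0.1),
    #         batch_size=32,
    #         test_batch_size=64,
    #         device="cuda",
    #         global_train_split="valid",
    #         global_epochs=1,
    #     )
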
    def init(server_storage, *args, **kwargs):
        default_global_train_split = "valid"
        default_global_epochs = 1
        FedAvg.init(server_storage)
        server_storage.write(
            "global_train_split",
            kwargs.get("global_train_split", default_global_train_split),
        )
        server_storage.write(
            "global_epochs",
            kwargs.get("global_epochs", default_global_epochs),
        )

    def receive_from_client(
        server_storage,
        client_id,
        client_msg,
        train_split_name,
        serial_aggregator,
        appendix_aggregator,
    ):
        params = client_msg["local_params"].clone().detach().data
        appendix_aggregator.append("local_params", params)
        return serial_aggregation(
            server_storage,
            client_id,
            client_msg,
            train_split_name,
            serial_aggregator,
        )
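
    # Hedged reading of the two aggregators used above (their semantics are
    # assumed from how this module calls them, not from their implementations):
    #
    #     appendix_aggregator.append("local_params", p_k)  # keep every client's
    #                                                      # params for the ensemble
    #     serial_aggregator.add("local_params", p_k, w_k)  # running weighted average
    #                                                      # (FedAvg-style fusion)
    #
    # ``optimize`` below consumes both: the serial average re-initializes the
    # cloud parameters, while the appendix entries provide the teacher ensemble.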

    def optimize(server_storage, serial_aggregator, appendix_aggregator):
        if "local_params" in serial_aggregator:
            param_avg = serial_aggregator.pop("local_params")
            cloud_params = server_storage.read("cloud_params")
            cloud_params.data = param_avg.data
            model = server_storage.read("model")
            global_train_split = server_storage.read("global_train_split")
            train_data_loader = server_storage.read("global_dataloaders").get(
                global_train_split
            )
            if train_data_loader is None:
                raise Exception(
                    f"no dataloader made for split {global_train_split} on the server!"
                )
            global_epochs = server_storage.read("global_epochs")
            optimizer = server_storage.read("optimizer")
            lr_scheduler = server_storage.read("lr_scheduler")
            device = server_storage.read("device")
            for _ in range(global_epochs):
                for x, _ in train_data_loader:
                    x = x.to(device)
                    target_agg = SerialAggregator()
                    for local_params in appendix_aggregator.get_values("local_params"):
                        initialize_module(model, local_params)
                        target = model(x).clone().detach()
                        target_agg.add("target", target, 1)
                    target_out = log_softmax(target_agg.get("target"), 1)
                    # initialize_module(model, cloud_params)
                    criterion = KLDivScore(log_target=True)
                    param_dict = vector_to_named_parameters_like(
                        cloud_params, model.named_parameters()
                    )
                    pred = functional_call(model, param_dict, x)
                    pred_out = log_softmax(pred, 1)
                    loss = criterion(pred_out, target_out.data)
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
            if lr_scheduler is not None:
                lr_scheduler.step()
            del target_agg
        return serial_aggregator.pop_all()
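

# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the FedDF class: the server-side ensemble-
# distillation step of ``optimize`` re-expressed in plain PyTorch. The names
# ``server_model``, ``client_state_dicts``, ``unlabeled_loader`` and the
# function itself are assumptions introduced only for this example; the real
# algorithm uses fedsim's aggregators and ``functional_call`` as shown above.
# ---------------------------------------------------------------------------
def _distillation_sketch(
    server_model, client_state_dicts, unlabeled_loader, optimizer, device="cpu"
):
    import copy

    import torch
    from torch.nn.functional import kl_div

    # one frozen teacher copy per client; a plain-PyTorch stand-in for the
    # initialize_module / appendix aggregator machinery used in ``optimize``
    teachers = []
    for state in client_state_dicts:
        teacher_model = copy.deepcopy(server_model).to(device)
        teacher_model.load_state_dict(state)
        teacher_model.eval()
        teachers.append(teacher_model)

    server_model.to(device).train()
    for x, _ in unlabeled_loader:
        x = x.to(device)
        with torch.no_grad():
            # teacher signal: average the client models' logits on this batch
            teacher = log_softmax(torch.stack([t(x) for t in teachers]).mean(0), 1)
        student = log_softmax(server_model(x), 1)
        # KL divergence with log-space targets, like KLDivScore(log_target=True)
        loss = kl_div(student, teacher, log_target=True, reduction="batchmean")
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()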