AvgLogits
- class FedDF(data_manager, metric_logger, num_clients, sample_scheme, sample_rate, model_def, epochs, criterion_def, optimizer_def=functools.partial(<class 'torch.optim.sgd.SGD'>, lr=1.0), local_optimizer_def=functools.partial(<class 'torch.optim.sgd.SGD'>, lr=0.1), lr_scheduler_def=None, local_lr_scheduler_def=None, r2r_local_lr_scheduler_def=None, batch_size=32, test_batch_size=64, device='cpu', *args, **kwargs)
Ensemble Distillation for Robust Model Fusion in Federated Learning.
For further details regarding the algorithm we refer to Ensemble Distillation for Robust Model Fusion in Federated Learning.
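In the paper, the server first fuses the client models (starting from their average) and then distills the ensemble of client models into the server model on server-side data. With the AvgLogits ensemble, the teacher signal is the average of the clients' logits. The minimal sketch below shows only that teacher/student loss for a single sample, in plain Python. The helper names, the temperature argument, and the per-sample formulation are assumptions for illustration, not fedsim's implementation.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float], temperature: float = 1.0) -> List[float]:
    """Numerically stable softmax with temperature scaling."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def avg_logits(client_logits: Sequence[Sequence[float]]) -> List[float]:
    """AvgLogits teacher: element-wise mean of the client models' logits."""
    n = len(client_logits)
    return [sum(col) / n for col in zip(*client_logits)]


def kl_divergence(p: Sequence[float], q: Sequence[float]) -> float:
    """KL(p || q) for two discrete distributions over the same classes."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def distillation_loss(
    student_logits: Sequence[float],
    client_logits: Sequence[Sequence[float]],
    temperature: float = 1.0,
) -> float:
    """KL between the ensemble (teacher) and server model (student) predictions."""
    teacher = softmax(avg_logits(client_logits), temperature)
    student = softmax(student_logits, temperature)
    return kl_divergence(teacher, student)
```

For instance, two clients with logits [1.0, 3.0] and [3.0, 1.0] average to [2.0, 2.0], a uniform teacher, so a student that also outputs equal logits incurs zero loss.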
- Parameters
data_manager (distributed.data_management.DataManager) -- data manager
metric_logger (logall.Logger) -- metric logger for tracking.
num_clients (int) -- number of clients
sample_scheme (str) -- mode of sampling clients. Options are 'uniform' and 'sequential'
sample_rate (float) -- fraction of clients sampled in each round
model_def (torch.nn.Module) -- definition for constructing the model
epochs (int) -- number of local epochs
criterion_def (Callable) -- loss function defining the local objective
optimizer_def (Callable) -- definition of the server optimizer
local_optimizer_def (Callable) -- definition of the local optimizer
lr_scheduler_def (Callable) -- definition of the lr scheduler of the server optimizer
local_lr_scheduler_def (Callable) -- definition of the lr scheduler of the local optimizer
r2r_local_lr_scheduler_def (Callable) -- definition of the schedule for the lr delivered to the clients at each round (determines the initial lr of the client optimizer)
batch_size (int) -- batch size of the local training
test_batch_size (int) -- batch size used at inference time
device (str) -- cpu, cuda, or gpu number
global_train_split (str) -- name of the train split to be used on the server
global_epochs (int) -- number of training epochs on the server
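The sampling parameters can be pictured with the hypothetical sampler below. It assumes that sample_rate is the fraction of clients chosen per round, that 'uniform' draws that many distinct clients at random, and that 'sequential' walks through the clients in consecutive chunks; fedsim's exact rounding and ordering may differ.

```python
import random
from typing import List, Optional


def sample_clients(
    num_clients: int,
    sample_rate: float,
    sample_scheme: str,
    round_idx: int = 0,
    rng: Optional[random.Random] = None,
) -> List[int]:
    """Hypothetical client sampler illustrating 'uniform' vs 'sequential'."""
    # Number of clients taking part in a round (at least one).
    k = max(1, int(sample_rate * num_clients))
    if sample_scheme == "uniform":
        # Draw k distinct client ids uniformly at random each round.
        rng = rng or random.Random()
        return sorted(rng.sample(range(num_clients), k))
    if sample_scheme == "sequential":
        # Walk through the clients in consecutive chunks of size k, wrapping around.
        start = (round_idx * k) % num_clients
        return [(start + i) % num_clients for i in range(k)]
    raise ValueError(f"unknown sample_scheme: {sample_scheme!r}")
```

With 6 clients and sample_rate=0.5, the sequential scheme picks clients [0, 1, 2] in round 0 and [3, 4, 5] in round 1.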
Note
Definition of:
- learning rate schedulers: could be any of the ones defined in torch.optim.lr_scheduler, or any other that implements the step and get_last_lr methods.
- optimizers: could be any torch.optim.Optimizer.
- model: could be any torch.nn.Module.
- criterion: could be any fedsim.scores.Score.
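Because the definitions are callables, a functools.partial that fixes the hyperparameters is enough, as the default local_optimizer_def in the signature (a partial of torch.optim.SGD with lr=0.1) shows. The sketch below is a hypothetical scheduler that relies only on step and get_last_lr, written in plain Python. It assumes the definition is later called with an initial learning rate; the arguments fedsim actually passes when it calls a scheduler definition are not documented here.

```python
import functools
from typing import List


class StepHalvingLR:
    """Hypothetical scheduler exposing only step() and get_last_lr().

    It halves the learning rate every `step_size` calls to step(). Like the
    torch.optim.lr_scheduler classes, get_last_lr() returns a list.
    """

    def __init__(self, init_lr: float, step_size: int = 1) -> None:
        self.init_lr = init_lr
        self.step_size = step_size
        self.num_steps = 0

    def step(self) -> None:
        # Advance the schedule by one step (e.g. one round).
        self.num_steps += 1

    def get_last_lr(self) -> List[float]:
        # Halve once per completed block of `step_size` steps.
        return [self.init_lr * 0.5 ** (self.num_steps // self.step_size)]


# A "definition" builds the object once the remaining arguments are known,
# here with the hyperparameter step_size already fixed.
scheduler_def = functools.partial(StepHalvingLR, step_size=2)
```

Calling scheduler_def(init_lr=0.1) yields a scheduler that reports 0.1 for the first two steps and 0.05 after the second call to step().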
Warning
This algorithm needs a split for training on the server. This means that the global datasets provided by the data manager must include an extra split, whose name is given by global_train_split.
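A cheap way to respect this requirement is to check for the server split before training starts. The helper below is hypothetical and assumes the global datasets are exposed as a mapping from split name to dataset; the split names in the usage lines are placeholders.

```python
from typing import Any, Mapping


def check_server_split(global_datasets: Mapping[str, Any], global_train_split: str) -> None:
    """Raise early if the split used for server-side training is missing (hypothetical)."""
    if global_train_split not in global_datasets:
        available = ", ".join(sorted(global_datasets)) or "<none>"
        raise KeyError(
            f"server split {global_train_split!r} not found; available splits: {available}"
        )


# Placeholder layout: one split for evaluation plus the extra split for server training.
global_datasets = {"test": [], "valid": []}
check_server_split(global_datasets, "valid")  # passes silently
```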
- init(*args, **kwrag)
This method is executed only once, when the algorithm object is instantiated. Here you define your model and whatever else is needed during training. Remember to write the outcome of your processing to server_storage so that other methods can access it.
Note
*args and **kwargs are passed directly through from the algorithm constructor.
- Parameters
server_storage (Storage) -- server storage object
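The pattern described for init, computing state once and persisting it in server_storage, might look like the sketch below. DictStorage is a hypothetical stand-in with write and read methods, and the keys (cloud_params, global_epochs, global_train_split) are illustrative; consult fedsim's Storage class for its real interface.

```python
from typing import Any, Dict, Hashable


class DictStorage:
    """Hypothetical stand-in for the server Storage: a plain key/value store."""

    def __init__(self) -> None:
        self._data: Dict[Hashable, Any] = {}

    def write(self, key: Hashable, obj: Any) -> None:
        self._data[key] = obj

    def read(self, key: Hashable) -> Any:
        return self._data[key]


def init(server_storage: DictStorage, *args: Any, **kwargs: Any) -> None:
    """Runs once at construction time; everything later methods need goes to storage."""
    # Hypothetical initial server model, represented as a flat list of weights.
    server_storage.write("cloud_params", [0.0, 0.0, 0.0])
    # Server-side hyperparameters, assumed to be forwarded from the constructor via **kwargs.
    server_storage.write("global_epochs", kwargs["global_epochs"])
    server_storage.write("global_train_split", kwargs["global_train_split"])
```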
- optimize(serial_aggregator, appendix_aggregator)
Optimize the server model(s) and return scores to be reported.
- Parameters
server_storage (Storage) -- server storage object.
serial_aggregator (SerialAggregator) -- serial aggregator instance of current round.
appendix_aggregator (AppendixAggregator) -- appendix aggregator instance of current round.
- Raises
NotImplementedError -- abstract method to be implemented by child classes
- Returns
Mapping[Hashable, Any] -- context to be reported
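One plausible reading of the default server optimizer (SGD with lr=1.0): if the server optimizer is stepped on the pseudo-gradient, i.e. the current server parameters minus the average of the client parameters, a learning rate of 1.0 lands on that average, which is plain FedAvg, and FedDF uses the averaged model as the starting point for distillation. The sketch below assumes this pseudo-gradient formulation, represents parameters as flat lists, and uses a hypothetical report key; it is not taken from fedsim's source.

```python
from typing import Dict, List, Tuple


def server_sgd_step(
    cloud_params: List[float], avg_client_params: List[float], lr: float = 1.0
) -> List[float]:
    """One plain SGD step on the pseudo-gradient (hypothetical helper)."""
    # Pseudo-gradient: how far the server model is from the client average.
    pseudo_grad = [w - a for w, a in zip(cloud_params, avg_client_params)]
    # w <- w - lr * g; with lr == 1.0 this equals the client average
    # (up to floating-point rounding).
    return [w - lr * g for w, g in zip(cloud_params, pseudo_grad)]


def optimize_sketch(
    cloud_params: List[float], avg_client_params: List[float], lr: float = 1.0
) -> Tuple[List[float], Dict[str, float]]:
    """Update the server model and return a report mapping, as optimize() does."""
    new_params = server_sgd_step(cloud_params, avg_client_params, lr)
    # Report the size of the update; a real FedDF step would also run distillation.
    update_norm = sum((n - w) ** 2 for n, w in zip(new_params, cloud_params)) ** 0.5
    return new_params, {"server.update_norm": update_norm}
```

A smaller server learning rate only moves part of the way toward the client average, which is why lr=1.0 is the natural default when the goal is to reproduce averaging.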
- receive_from_client(client_id, client_msg, train_split_name, serial_aggregator, appendix_aggregator)
Receive and aggregate info from the selected clients.
- Parameters
server_storage (Storage) -- server storage object.
client_id (int) -- id of the sender (client)
client_msg (Mapping[Hashable, Any]) -- client context that is sent.
train_split_name (str) -- name of the training split on clients.
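serial_aggregator (SerialAggregator) -- serial aggregator instance of current round, used to collect info.
appendix_aggregator (AppendixAggregator) -- appendix aggregator instance of current round.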
- Returns
bool -- success of the aggregation.
- Raises
NotImplementedError -- abstract method to be implemented by child classes
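Since clients report one at a time, a serial aggregator can keep a running weighted sum per key instead of storing every client's message. RunningWeightedMean below is a hypothetical stand-in for SerialAggregator, and the message keys local_params and num_samples are assumptions; the sketch weights each client by its number of samples, as FedAvg does.

```python
from typing import Any, Dict, Hashable, List, Mapping, Optional


class RunningWeightedMean:
    """Hypothetical stand-in for a serial aggregator.

    Only a running weighted sum and the total weight are kept per key, so
    memory does not grow with the number of reporting clients.
    """

    def __init__(self) -> None:
        self._sums: Dict[Hashable, List[float]] = {}
        self._weights: Dict[Hashable, float] = {}

    def add(self, key: Hashable, value: List[float], weight: float) -> None:
        if key not in self._sums:
            self._sums[key] = [0.0] * len(value)
            self._weights[key] = 0.0
        self._sums[key] = [s + weight * v for s, v in zip(self._sums[key], value)]
        self._weights[key] += weight

    def get(self, key: Hashable) -> List[float]:
        total = self._weights[key]
        return [s / total for s in self._sums[key]]


def receive_from_client(
    client_id: int,
    client_msg: Mapping[Hashable, Any],
    train_split_name: str,
    serial_aggregator: RunningWeightedMean,
    appendix_aggregator: Optional[Any] = None,
) -> bool:
    """Fold one client's parameters into the running mean; report success."""
    # train_split_name and appendix_aggregator are unused here (signature parity only).
    params = client_msg.get("local_params")
    num_samples = client_msg.get("num_samples", 0)
    if params is None or num_samples <= 0:
        return False
    serial_aggregator.add("local_params", params, weight=num_samples)
    return True
```

Two clients sending [1.0, 2.0] with 1 sample and [4.0, 8.0] with 2 samples aggregate to [3.0, 6.0].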