r"""
FedAvg
------
"""
import math
from torch.utils.data import DataLoader
from torch.utils.data import RandomSampler
from fedsim.local.training import local_inference
from fedsim.local.training import local_train
from fedsim.local.training.step_closures import default_step_closure
from fedsim.utils import initialize_module
from fedsim.utils import vectorize_module
from ..centralized_fl_algorithm import CentralFLAlgorithm
from .utils import serial_aggregation
# from ._shared_docs import doc_args, doc_arc, doc_note
class FedAvg(CentralFLAlgorithm):
    r"""
    Implements FedAvg algorithm for centralized FL.

    For further details regarding the algorithm we refer to
    `Communication-Efficient Learning of Deep Networks from Decentralized Data`_.

    Args:
        data_manager (``distributed.data_management.DataManager``): data manager
        metric_logger (``logall.Logger``): metric logger for tracking.
        num_clients (int): number of clients
        sample_scheme (``str``): mode of sampling clients. Options are ``'uniform'``
            and ``'sequential'``
        sample_rate (``float``): rate of sampling clients
        model_def (``torch.Module``): definition for constructing the model
        epochs (``int``): number of local epochs
        criterion_def (``Callable``): loss function defining local objective
        optimizer_def (``Callable``): definition of server optimizer
        local_optimizer_def (``Callable``): definition of local optimizer
        lr_scheduler_def (``Callable``): definition of lr scheduler of server
            optimizer.
        local_lr_scheduler_def (``Callable``): definition of lr scheduler of local
            optimizer
        r2r_local_lr_scheduler_def (``Callable``): definition to schedule lr that is
            delivered to the clients at each round (determines init lr of the
            client optimizer)
        batch_size (int): batch size of the local training
        test_batch_size (int): inference time batch size
        device (str): cpu, cuda, or gpu number

    .. note::
        definition of

        * learning rate schedulers, could be any of the ones defined at
          ``torch.optim.lr_scheduler`` or any other that implements step and
          get_last_lr methods.
        * optimizers, could be any ``torch.optim.Optimizer``.
        * model, could be any ``torch.Module``.
        * criterion, could be any ``fedsim.scores.Score``.

    .. _Communication-Efficient Learning of Deep Networks from Decentralized
        Data: https://arxiv.org/abs/1602.05629
    """

    def init(server_storage):
        """Build the global model and server optimizer state.

        Stores ``model``, a flat ``cloud_params`` vector, the server
        ``optimizer`` and its (optional) ``lr_scheduler`` in server storage.
        """
        device = server_storage.read("device")
        model = server_storage.read("model_def")().to(device)
        # keep a detached, flattened copy of the model parameters as the cloud
        # state; requires_grad lets the server optimizer consume pseudo-grads
        params = vectorize_module(model, clone=True, detach=True)
        params.requires_grad = True
        optimizer = server_storage.read("optimizer_def")(params=[params])
        lr_scheduler = None
        lr_scheduler_def = server_storage.read("lr_scheduler_def")
        if lr_scheduler_def is not None:
            lr_scheduler = lr_scheduler_def(optimizer=optimizer)
        server_storage.write("model", model)
        server_storage.write("cloud_params", params)
        server_storage.write("optimizer", optimizer)
        server_storage.write("lr_scheduler", lr_scheduler)

    def send_to_client(server_storage, client_id):
        """Prepare the payload delivered to a sampled client.

        Copies the current cloud parameter vector into the shared model object
        and ships that model; ``client_id`` is unused by FedAvg (all clients
        receive the same global model).
        """
        # load cloud stuff
        cloud_params = server_storage.read("cloud_params")
        model = server_storage.read("model")
        # copy cloud params into the model instance to send to the client
        initialize_module(model, cloud_params, clone=True, detach=True)
        # return a copy of the cloud model
        return dict(model=model)

    # define client operation
    def send_to_server(
        id,
        rounds,
        storage,
        datasets,
        train_split_name,
        scores,
        epochs,
        criterion,
        train_batch_size,
        inference_batch_size,
        optimizer_def,
        lr_scheduler_def=None,
        device="cuda",
        ctx=None,
        step_closure=None,
    ):
        """Run local training on a client and package the reply to the server.

        Trains ``ctx['model']`` on the client's train split, evaluates the
        requested score objects on the remaining splits, and returns the
        flattened local parameters together with step counts, divergence flag,
        per-split sample counts and metrics.
        """
        # create a random sampler with replacement so that
        # stochasticity is maximized and privacy is not compromised;
        # num_samples is rounded up to a whole number of batches
        sampler = RandomSampler(
            datasets[train_split_name],
            replacement=True,
            num_samples=math.ceil(len(datasets[train_split_name]) / train_batch_size)
            * train_batch_size,
        )
        # create train data loader
        train_loader = DataLoader(
            datasets[train_split_name], batch_size=train_batch_size, sampler=sampler
        )
        model = ctx["model"]
        optimizer = optimizer_def(model.parameters())
        if lr_scheduler_def is not None:
            lr_scheduler = lr_scheduler_def(optimizer=optimizer)
        else:
            lr_scheduler = None
        # optimize the model locally
        step_closure_ = default_step_closure if step_closure is None else step_closure
        if train_split_name in scores:
            train_scores = scores[train_split_name]
        else:
            train_scores = dict()
        num_train_samples, num_steps, diverged = local_train(
            model,
            train_loader,
            epochs,
            0,
            criterion,
            optimizer,
            lr_scheduler,
            device,
            step_closure_,
            scores=train_scores,
        )
        # get average train scores
        metrics_dict = {
            train_split_name: {
                name: score.get_score() for name, score in train_scores.items()
            }
        }
        # append train loss (only on rounds matching the criterion's log freq)
        if rounds % criterion.log_freq == 0:
            metrics_dict[train_split_name][criterion.get_name()] = criterion.get_score()
        num_samples_dict = {train_split_name: num_train_samples}
        # evaluate the trained model on every other split that has scores
        for split_name, split in datasets.items():
            if split_name != train_split_name and split_name in scores:
                o_scores = scores[split_name]
                split_loader = DataLoader(
                    split,
                    batch_size=inference_batch_size,
                    shuffle=False,
                )
                num_samples = local_inference(
                    model,
                    split_loader,
                    scores=o_scores,
                    device=device,
                )
                metrics_dict[split_name] = {
                    name: score.get_score() for name, score in o_scores.items()
                }
                num_samples_dict[split_name] = num_samples
        # return optimized model parameters and number of train samples
        return dict(
            local_params=vectorize_module(model),
            num_steps=num_steps,
            diverged=diverged,
            num_samples=num_samples_dict,
            metrics=metrics_dict,
        )

    def receive_from_client(
        server_storage,
        client_id,
        client_msg,
        train_split_name,
        serial_aggregator,
        appendix_aggregator,
    ):
        """Fold one client's message into the running serial aggregation.

        ``appendix_aggregator`` is accepted for interface compatibility but
        unused by plain FedAvg.
        """
        return serial_aggregation(
            server_storage, client_id, client_msg, train_split_name, serial_aggregator
        )

    def optimize(server_storage, serial_aggregator, appendix_aggregator):
        """Apply the aggregated client update to the cloud parameters.

        Treats (cloud - average) as a pseudo-gradient and takes one server
        optimizer step; returns the remaining aggregated reports.
        """
        if "local_params" in serial_aggregator:
            param_avg = serial_aggregator.pop("local_params")
            optimizer = server_storage.read("optimizer")
            lr_scheduler = server_storage.read("lr_scheduler")
            cloud_params = server_storage.read("cloud_params")
            pseudo_grads = cloud_params.data - param_avg
            # update cloud params
            optimizer.zero_grad()
            cloud_params.grad = pseudo_grads
            optimizer.step()
            if lr_scheduler is not None:
                lr_scheduler.step()
            # purge aggregated results
            del param_avg
        return serial_aggregator.pop_all()

    def deploy(server_storage):
        """Expose the deployment points; FedAvg deploys the averaged params."""
        return dict(avg=server_storage.read("cloud_params"))

    def report(
        server_storage,
        dataloaders,
        rounds,
        scores,
        metric_logger,
        device,
        optimize_reports,
        deployment_points=None,
    ):
        """Evaluate each deployment point on the server dataloaders.

        Returns a flat dict of ``server.<point>.<split>.<score>`` entries merged
        with ``optimize_reports`` and periodic parameter-norm reports.
        """
        model = server_storage.read("model")
        scores_from_deploy = dict()
        # TODO: reporting norm and similar scores should be implemented
        # through hooks (hook probe perhaps)
        norm_report_freq = 50
        norm_reports = dict()
        if deployment_points is not None:
            for point_name, point in deployment_points.items():
                # load the deployment point's params into the shared model
                initialize_module(model, point, clone=True, detach=True)
                for split_name, loader in dataloaders.items():
                    if split_name in scores:
                        split_scores = scores[split_name]
                        _ = local_inference(
                            model,
                            loader,
                            scores=split_scores,
                            device=device,
                        )
                        split_score_results = {
                            f"server.{point_name}.{split_name}."
                            f"{score_name}": score.get_score()
                            for score_name, score in split_scores.items()
                        }
                        scores_from_deploy = {
                            **scores_from_deploy,
                            **split_score_results,
                        }
                if rounds % norm_report_freq == 0:
                    norm_reports[
                        f"server.{point_name}.param.norm"
                    ] = point.norm().item()
        return {**scores_from_deploy, **optimize_reports, **norm_reports}