r"""
FedDyn
-------
"""
from functools import partial
import torch
from fedsim.local.training.step_closures import default_step_closure
from fedsim.utils import vector_to_parameters_like
from fedsim.utils import vectorize_module
from .fedavg import FedAvg
from .utils import serial_aggregation
class FedDyn(FedAvg):
    r"""Implements FedDyn algorithm for centralized FL.

    For further details regarding the algorithm we refer to `Federated Learning Based
    on Dynamic Regularization`_.

    Args:
        data_manager (``distributed.data_management.DataManager``): data manager
        metric_logger (``logall.Logger``): metric logger for tracking.
        num_clients (int): number of clients
        sample_scheme (``str``): mode of sampling clients. Options are ``'uniform'``
            and ``'sequential'``
        sample_rate (``float``): rate of sampling clients
        model_def (``torch.Module``): definition used for constructing the model
        epochs (``int``): number of local epochs
        criterion_def (``Callable``): loss function defining local objective
        optimizer_def (``Callable``): definition of server optimizer
        local_optimizer_def (``Callable``): definition of local optimizer
        lr_scheduler_def (``Callable``): definition of lr scheduler of server
            optimizer.
        local_lr_scheduler_def (``Callable``): definition of lr scheduler of local
            optimizer
        r2r_local_lr_scheduler_def (``Callable``): definition to schedule lr that is
            delivered to the clients at each round (determines the initial lr of
            the client optimizer)
        batch_size (int): batch size of the local training
        test_batch_size (int): inference time batch size
        device (str): cpu, cuda, or gpu number
        alpha (float): FedDyn's :math:`\alpha` hyper-parameter for local
            regularization

    .. note::
        definition of

        * learning rate schedulers, could be any of the ones defined at
          ``torch.optim.lr_scheduler`` or any other that implements ``step`` and
          ``get_last_lr`` methods.
        * optimizers, could be any ``torch.optim.Optimizer``.
        * model, could be any ``torch.Module``.
        * criterion, could be any ``fedsim.scores.Score``.

    .. _Federated Learning Based on Dynamic Regularization:
        https://openreview.net/forum?id=B7v4QMR6Z9w
    """

    def init(server_storage, *args, **kwargs):
        """Populate the server storage with the state FedDyn needs.

        Runs ``FedAvg.init`` first, then writes ``avg_params`` (a detached
        copy of ``cloud_params``), ``h`` (zeros shaped like ``cloud_params``),
        ``alpha`` and ``average_sample`` (oracle train-split size divided by
        the number of clients).

        Args:
            server_storage: server-side storage exposing ``read`` / ``write``.
            *args: accepted but not used here.
            **kwargs: only ``alpha`` is consulted; it defaults to ``0.1``.
        """
        default_alpha = 0.1
        local_train_split_name = "train"

        FedAvg.init(server_storage)
        cloud_params = server_storage.read("cloud_params")
        server_storage.write("avg_params", cloud_params.clone().detach())
        server_storage.write("h", torch.zeros_like(cloud_params))
        server_storage.write("alpha", kwargs.get("alpha", default_alpha))

        # oracle read violation, num_clients read violation
        print("Warning: private access violation")
        print("\t", end="")
        print(
            "FedDyn assumes prior knowledge on the number of clients and oracle samples"
        )
        oracle_dataset = server_storage.read("oracle_dataset")[local_train_split_name]
        num_clients = server_storage.read("num_clients")
        # written once, after it is known; no placeholder value is stored first
        average_sample = len(oracle_dataset) / num_clients
        server_storage.write("average_sample", average_sample)

    def send_to_client(server_storage, client_id):
        """Build the message delivered to a client.

        Extends the message produced by ``FedAvg.send_to_client`` with the
        ``average_sample`` and ``alpha`` entries read from the server storage.

        Args:
            server_storage: server-side storage exposing ``read``.
            client_id: identifier forwarded to ``FedAvg.send_to_client``.

        Returns:
            The FedAvg message with the two extra keys added.
        """
        msg = FedAvg.send_to_client(server_storage, client_id)
        msg["average_sample"] = server_storage.read("average_sample")
        msg["alpha"] = server_storage.read("alpha")
        return msg

    def send_to_server(
        id,
        rounds,
        storage,
        datasets,
        train_split_name,
        metrics,
        epochs,
        criterion,
        train_batch_size,
        inference_batch_size,
        optimizer_def,
        lr_scheduler_def=None,
        device="cuda",
        ctx=None,
        step_closure=None,
    ):
        """Run local training with FedDyn's gradient correction.

        Delegates the optimization to ``FedAvg.send_to_server`` with a step
        closure that adds a correction term to every parameter gradient, then
        accumulates the local drift ``params_init - params_final`` into the
        ``h`` entry of the client storage.

        Args:
            id, rounds, datasets, metrics, epochs, criterion, train_batch_size,
                inference_batch_size, optimizer_def, lr_scheduler_def, device:
                forwarded unchanged to ``FedAvg.send_to_server``.
            storage: client storage; its ``h`` entry is read and rewritten.
            train_split_name: key of ``datasets`` whose length scales ``alpha``.
            ctx: mapping that must provide ``model``, ``alpha`` and
                ``average_sample`` (the ``None`` default is not usable here).
            step_closure: accepted for signature compatibility; it is not used,
                the closure built in this method is always the one passed on.

        Returns:
            Whatever ``FedAvg.send_to_server`` returns.
        """
        model = ctx["model"]
        alpha = ctx["alpha"]
        average_sample = ctx["average_sample"]

        # snapshot of the parameters received for this round
        params_init = vectorize_module(model, clone=True, detach=True)
        h = storage.read("h")
        # alpha rescaled by (average sample count / this client's sample count)
        alpha_adaptive = alpha / len(datasets[train_split_name]) * average_sample

        def transform_grads_fn(model):
            """Add ``alpha_adaptive * (0.5 * (params - params_init) - h)`` to grads."""
            params = vectorize_module(model)
            # NOTE(review): the gradient of (alpha / 2) * ||p - p0||^2 would be
            # alpha * (p - p0); confirm the 0.5 factor here is intended.
            grad_additive = 0.5 * (params - params_init)
            if h is not None:
                grad_additive -= h
            # presumably splits the flat vector into per-parameter tensors
            grad_additive_list = vector_to_parameters_like(
                alpha_adaptive * grad_additive, model.parameters()
            )
            for p, g_a in zip(model.parameters(), grad_additive_list):
                p.grad += g_a

        step_closure_ = partial(
            default_step_closure, transform_grads=transform_grads_fn
        )
        opt_res = FedAvg.send_to_server(
            id,
            rounds,
            storage,
            datasets,
            train_split_name,
            metrics,
            epochs,
            criterion,
            train_batch_size,
            inference_batch_size,
            optimizer_def,
            lr_scheduler_def,
            device,
            ctx,
            step_closure=step_closure_,
        )

        # update local h
        pseudo_grads = params_init - vectorize_module(model, clone=True, detach=True)
        new_h = pseudo_grads if h is None else pseudo_grads + h
        storage.write("h", new_h)
        return opt_res

    def receive_from_client(
        server_storage,
        client_id,
        client_msg,
        train_split_name,
        serial_aggregator,
        appendix_aggregator,
    ):
        """Aggregate one client's message with a fixed weight of 1.

        Args:
            server_storage, client_id, client_msg, train_split_name,
                serial_aggregator: forwarded to ``serial_aggregation``.
            appendix_aggregator: accepted but not used here.

        Returns:
            The result of ``serial_aggregation``.
        """
        weight = 1
        return serial_aggregation(
            server_storage,
            client_id,
            client_msg,
            train_split_name,
            serial_aggregator,
            train_weight=weight,
        )

    def optimize(server_storage, serial_aggregator, appendix_aggregator):
        """Apply the server-side FedDyn update from the aggregated parameters.

        When ``local_params`` is present in the aggregator, the server state
        ``h`` is advanced by ``weight / num_clients * (cloud - param_avg)``,
        the target ``param_avg - h`` is turned into a pseudo-gradient for the
        server optimizer, and ``avg_params`` / ``h`` are written back.

        Args:
            server_storage: server-side storage exposing ``read`` / ``write``.
            serial_aggregator: aggregator holding ``local_params``.
            appendix_aggregator: accepted but not used here.

        Returns:
            The result of ``serial_aggregator.pop_all()``.
        """
        if "local_params" in serial_aggregator:
            weight = serial_aggregator.get_weight("local_params")
            param_avg = serial_aggregator.pop("local_params")
            optimizer = server_storage.read("optimizer")
            lr_scheduler = server_storage.read("lr_scheduler")
            cloud_params = server_storage.read("cloud_params")
            pseudo_grads = cloud_params.data - param_avg
            h = server_storage.read("h")
            num_clients = server_storage.read("num_clients")
            # already warned, disable read protection from num_clients
            # NOTE(review): protection is lifted after the read above; confirm
            # whether it was meant to precede it.
            server_storage.change_protection("num_clients", False, False)
            h = h + weight / num_clients * pseudo_grads
            new_params = param_avg - h
            modified_pseudo_grads = cloud_params.data - new_params

            # update cloud params
            optimizer.zero_grad()
            cloud_params.grad = modified_pseudo_grads
            optimizer.step()
            if lr_scheduler is not None:
                lr_scheduler.step()

            server_storage.write("avg_params", param_avg.detach().clone())
            server_storage.write("h", h.data)
            # purge aggregated results
            del param_avg
        return serial_aggregator.pop_all()

    def deploy(server_storage):
        """Return the deployable parameter sets.

        Returns:
            dict: ``cloud`` maps to ``cloud_params`` and ``avg`` to
            ``avg_params``, both read from the server storage.
        """
        return dict(
            cloud=server_storage.read("cloud_params"),
            avg=server_storage.read("avg_params"),
        )