# Source code for fedsim.distributed.centralized.training.utils

"""
Distributed Centralized Training Utils
--------------------------------------
"""


def serial_aggregation(
    server_storage,
    client_id,
    client_msg,
    train_split_name,
    aggregator,
    train_weight=None,
    other_weight=None,
    purge_msg=True,
):
    """To serially aggregate received message from a client.

    Args:
        server_storage (Storage): server storage object. Not read by this
            function; kept for interface compatibility.
        client_id (int): client id. Not read by this function.
        client_msg (Mapping): client message. Must contain the keys
            ``"diverged"``, ``"num_samples"``, ``"metrics"`` and
            ``"local_params"``.
        train_split_name (str): name of the training split on clients.
        aggregator (SerialAggregator): a serial aggregator to accumulate info.
        train_weight (float, optional): aggregation weight for training
            parameters. If not specified, uses the number of samples of
            ``train_split_name``. Parameters are only added when this weight
            is positive. Defaults to None.
        other_weight (float, optional): aggregation weight for any other
            factor/metric. If not specified, each split's metrics are weighted
            by that split's own sample number. Defaults to None.
        purge_msg (bool, optional): whether to drop this function's reference
            to ``client_msg`` after aggregation. Defaults to True.

    Returns:
        bool: success of aggregation (False when the client diverged).
    """
    # Reject diverged clients before doing any (potentially expensive) copy.
    if client_msg["diverged"]:
        return False

    n_samples = client_msg["num_samples"]

    if train_weight is None:
        train_weight = n_samples[train_split_name]
    if train_weight > 0:
        # Detach first so the copy is not recorded by autograd, then clone so
        # the aggregator never aliases the client's parameter storage.
        params = client_msg["local_params"].detach().clone()
        aggregator.add("local_params", params, train_weight)

    for split_name, split_metrics in client_msg["metrics"].items():
        # Resolve the weight per split without overwriting `other_weight`;
        # otherwise the first split's sample count would leak into all others.
        if other_weight is None:
            split_weight = n_samples[split_name]
        else:
            split_weight = other_weight
        for key, metric in split_metrics.items():
            aggregator.add(f"clients.{split_name}.{key}", metric, split_weight)

    # purge client info
    # NOTE(review): ``del`` only unbinds this function's local name; the
    # caller's reference keeps the message alive. Confirm whether an actual
    # purge (e.g. ``client_msg.clear()``) is intended.
    if purge_msg:
        del client_msg
    return True