Source code for fedsim.utils.convert_parameters

r"""
Parameters Conversion
---------------------
"""
from collections import OrderedDict
from collections.abc import Mapping

import torch
from torch import Tensor
from torch.nn import Module
from torch.nn.utils import parameters_to_vector
from torch.nn.utils import vector_to_parameters
from torch.nn.utils.convert_parameters import _check_param_device


def vector_to_parameters_like(vec, parameters_like):
    r"""Split one vector into new tensors shaped like the given parameters.

    Args:
        vec (Tensor): a single vector representing the parameters of a model.
        parameters_like (Iterable[Tensor]): tensors whose shapes (and order)
            define how ``vec`` is sliced. They are only read for their sizes;
            new tensors are returned and these are left untouched.

    Returns:
        list[Tensor]: one tensor per entry of ``parameters_like``, each viewed
        to that entry's shape and taken through ``.data`` (so it does not
        require grad). Trailing elements of ``vec`` beyond the total size of
        ``parameters_like`` are ignored.

    Raises:
        TypeError: if ``vec`` is not a ``torch.Tensor``.
    """
    if not isinstance(vec, torch.Tensor):
        raise TypeError(
            "expected torch.Tensor, but got: {}".format(torch.typename(vec))
        )

    result = []
    offset = 0
    for reference in parameters_like:
        end = offset + reference.numel()
        # Take the next contiguous run of elements and give it the
        # reference tensor's shape.
        chunk = vec[offset:end]
        result.append(chunk.view_as(reference).data)
        offset = end
    return result
def vector_to_named_parameters_like(
    vec: Tensor,
    named_parameters_like: OrderedDict,
) -> OrderedDict:
    r"""Split one vector into new named tensors shaped like the ones provided.

    Args:
        vec (Tensor): a single vector representing the parameters of a model.
        named_parameters_like (OrderedDict): a mapping from names to Tensors
            that are the parameters of a model. An iterable of
            ``(name, tensor)`` pairs (e.g. ``module.named_parameters()``) is
            also accepted. It is only used for its keys, order and sizes;
            new tensors are returned.

    Returns:
        OrderedDict: the same keys, in the same order, each mapped to a slice
        of ``vec`` viewed to the corresponding tensor's shape. The values are
        views of ``vec`` (no copy, no ``.data``), so they share its storage.

    Raises:
        TypeError: if ``vec`` is not a ``torch.Tensor``.
    """
    # Ensure vec of type Tensor
    if not isinstance(vec, torch.Tensor):
        raise TypeError(
            "expected torch.Tensor, but got: {}".format(torch.typename(vec))
        )
    # Iterating a mapping directly yields only its keys, which cannot be
    # unpacked into (key, param); use .items() for mappings and keep
    # accepting plain iterables of (name, tensor) pairs.
    if isinstance(named_parameters_like, Mapping):
        pairs = named_parameters_like.items()
    else:
        pairs = named_parameters_like
    # Pointer for slicing the vector for each parameter
    pointer = 0
    new_params = OrderedDict()
    for key, param in pairs:
        num_param = param.numel()
        # Slice the vector and reshape it like the reference parameter
        new_params[key] = vec[pointer : pointer + num_param].view_as(param)
        pointer += num_param
    return new_params
def vectorize_module(module: Module, clone=True, detach=True):
    r"""Flatten all parameters of a module into a single vector.

    Args:
        module (Module): module whose parameters are flattened.
        clone (bool, optional): clones the output. Defaults to True.
        detach (bool, optional): detaches the output. Defaults to True.

    Returns:
        Tensor: 1-D Tensor with every parameter of ``module`` concatenated in
        ``module.parameters()`` order.
    """
    flat = parameters_to_vector(module.parameters())
    flat = flat.clone() if clone else flat
    return flat.detach() if detach else flat
def vectorize_module_grads(module: Module, clone=True, detach=True):
    r"""Flatten the gradients of a module's parameters into a single vector.

    Args:
        module (Module): module whose parameter gradients are flattened.
        clone (bool, optional): clones the output. Defaults to True.
        detach (bool, optional): detaches the output. Defaults to True.

    Returns:
        Tensor: 1-D Tensor of the gradients of all parameters in the module,
        concatenated in ``module.parameters()`` order, or ``None`` if at
        least one parameter has no gradient (``param.grad is None``).
    """
    param_device = None
    vec = []
    for param in module.parameters():
        # Ensure the parameters are located in the same device
        param_device = _check_param_device(param, param_device)
        if param.grad is None:
            return None
        # reshape (not view) so non-contiguous gradients can be flattened
        # too; for contiguous gradients this is still a zero-copy view.
        vec.append(param.grad.reshape(-1))
    vec = torch.cat(vec)
    if clone:
        vec = vec.clone()
    if detach:
        vec = vec.detach()
    return vec
def initialize_module(module: Module, vec: Tensor, clone=True, detach=True):
    r"""Overwrite a module's parameters from a 1-D vector.

    Args:
        module (Module): module whose parameters are overwritten.
        vec (Tensor): a 1-D Tensor with one entry per parameter element,
            laid out in ``module.parameters()`` order.
        clone (bool, optional): clones ``vec`` before the initialization.
            Defaults to True.
        detach (bool, optional): detaches ``vec`` before the initialization.
            Defaults to True.

    Returns:
        bool: ``True`` if the parameters were overwritten; ``False`` (module
        left untouched) if ``len(vec)`` does not equal the module's total
        number of parameter elements.
    """
    # Count parameter elements directly instead of building a cloned,
    # flattened copy of every parameter just to measure its length, and
    # reject a mismatched vector before paying for cloning it.
    num_params = sum(param.numel() for param in module.parameters())
    if len(vec) != num_params:
        return False
    # NOTE(review): torch's vector_to_parameters may make the parameters
    # views of `vec`; clone=True keeps them from aliasing the caller's tensor.
    if clone:
        vec = vec.clone()
    if detach:
        vec = vec.detach()
    vector_to_parameters(vec, module.parameters())
    return True