# Source code for fedsim.models.utils

"""
Model Utils
-----------

"""

import torch


class ModelReconstructor(torch.nn.Module):
    """Reconstructs a model out of a feature extractor and a classifier.

    The forward pass runs ``feature_extractor`` on the input, optionally
    passes the result through ``connection_fn``, and feeds the outcome to
    ``classifier``.

    Args:
        feature_extractor (Module): feature-extractor module.
        classifier (Module): classifier module.
        connection_fn (Callable, optional): optional connection function to
            apply on the output of feature-extractor before feeding to the
            classifier. Defaults to None.
    """

    def __init__(self, feature_extractor, classifier, connection_fn=None) -> None:
        super().__init__()
        # Assignment order is kept as feature_extractor -> classifier ->
        # connection_fn so that submodule registration order is stable.
        self.feature_extractor = feature_extractor
        self.classifier = classifier
        self.connection_fn = connection_fn

    def forward(self, input):
        """Runs the input through the reconstructed model.

        Args:
            input: value passed as-is to ``feature_extractor``.

        Returns:
            The output of ``classifier`` applied to the (optionally
            connected) features.
        """
        # NOTE: ``input`` shadows the builtin, but the name is part of the
        # public signature (keyword callers), so it is kept unchanged.
        features = self.feature_extractor(input)
        if self.connection_fn is not None:
            features = self.connection_fn(features)
        return self.classifier(features)
def get_output_size(in_size, pad, kernel, stride):
    """Calculates the output size after applying a kernel (for one dimension).

    Args:
        in_size (int): input size.
        pad (int or str): padding size added to each side. If set to
            ``same``, input size is directly returned (kernel and stride are
            not consulted). If set to ``valid``, it is treated as zero
            padding.
        kernel (int): kernel size.
        stride (int): size of strides.

    Returns:
        int: output size
    """
    if pad == "same":
        return in_size
    if pad == "valid":
        # "valid" means no padding at all, so fall through with pad = 0.
        pad = 0
    # Floor division: a trailing partial window is dropped.
    return ((in_size + 2 * pad - kernel) // stride) + 1