Local Training

Training for a local client.

local_train(model, train_data_loader, epochs, steps, criterion, optimizer, lr_scheduler=None, device='cpu', step_closure=default_step_closure, scores=None, max_grad_norm=1000, **step_ctx)

Run local training of a model on a client's data.

Parameters:

  • model (Module) -- the model to train; its predictions are used to compute the loss.

  • train_data_loader (Iterable) -- training data loader.

  • epochs (int) -- number of local epochs.

  • steps (int) -- number of additional optimization steps (batches) to run after the final epoch.

  • criterion (Callable) -- loss criterion.

  • optimizer (Optimizer) -- a torch optimizer.

  • lr_scheduler (Any, optional) -- a torch learning rate scheduler. Defaults to None.

  • device (str, optional) -- device to load the data onto ("cpu", "cuda", or a device ordinal number). This must be the same device the model parameters are loaded onto. Defaults to "cpu".

  • step_closure (Callable, optional) -- step closure for an optimization step. Defaults to default_step_closure.

  • scores (Dict[str, Score], optional) -- a dictionary mapping score names (str) to Score objects. Defaults to None.

  • max_grad_norm (int, optional) -- maximum norm used to clip the gradients. Defaults to 1000.
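The `max_grad_norm` parameter is only described as clipping the gradient norm. The sketch below assumes it behaves like PyTorch's `torch.nn.utils.clip_grad_norm_`, which rescales all gradients together when their global L2 norm exceeds the threshold. It works on plain Python floats instead of tensors, and `clip_by_global_norm` is a hypothetical helper, not part of this API.

```python
import math
from typing import List


def clip_by_global_norm(grads: List[float], max_norm: float) -> List[float]:
    """Hypothetical helper: scale gradients so their global L2 norm is at most max_norm.

    Assumes the same rule as torch.nn.utils.clip_grad_norm_: the scale factor
    is max_norm / (total_norm + 1e-6), capped at 1.0 so that gradients already
    below the threshold are left unchanged.
    """
    total_norm = math.sqrt(sum(g * g for g in grads))
    coef = min(max_norm / (total_norm + 1e-6), 1.0)
    return [g * coef for g in grads]


# A gradient with norm 5.0 clipped to max_norm=1.0 is rescaled to norm ~1.0.
clipped = clip_by_global_norm([3.0, 4.0], max_norm=1.0)
print(clipped)

# With the default threshold of 1000, a small gradient passes through unchanged.
print(clip_by_global_norm([0.3, 0.4], max_norm=1000))
```

With the default of 1000, clipping (under this assumption) only kicks in for very large gradients, so it acts mainly as a safeguard against blow-ups rather than as a routine regularizer.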


Returns:

Tuple[int, int, bool] -- a tuple of (number of training samples, number of optimization steps, divergence flag).
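The toy loop below shows how `epochs`, `steps`, and the returned tuple might fit together. It is a hypothetical sketch, not the library's implementation. `toy_local_train` fits a one-parameter linear model with plain SGD. It also assumes that the `steps` extra batches are drawn by cycling the loader, that samples from both phases are counted, and that a non-finite loss signals divergence.

```python
import itertools
import math
from typing import Iterable, List, Tuple

Batch = Tuple[List[float], List[float]]  # (inputs, targets)


def toy_local_train(
    w: List[float],
    train_data_loader: Iterable[Batch],
    epochs: int,
    steps: int,
    lr: float = 0.1,
) -> Tuple[int, int, bool]:
    """Hypothetical sketch of a local training loop fitting y ~ w[0] * x.

    Returns (number of training samples, number of optimization steps, diverged).
    """
    num_samples = 0
    num_steps = 0

    def step(batch: Batch) -> bool:
        """One SGD step on mean squared error; returns False on divergence."""
        nonlocal num_samples, num_steps
        xs, ys = batch
        residuals = [w[0] * x - y for x, y in zip(xs, ys)]
        loss = sum(r * r for r in residuals) / len(xs)
        if not math.isfinite(loss):
            return False  # assumption: a non-finite loss means divergence
        grad = sum(2 * r * x for r, x in zip(residuals, xs)) / len(xs)
        w[0] -= lr * grad
        num_samples += len(xs)
        num_steps += 1
        return True

    # Phase 1: `epochs` full passes over the loader.
    for _ in range(epochs):
        for batch in train_data_loader:
            if not step(batch):
                return num_samples, num_steps, True

    # Phase 2: `steps` extra batches, cycling the loader (assumption).
    for batch in itertools.islice(itertools.cycle(train_data_loader), steps):
        if not step(batch):
            return num_samples, num_steps, True

    return num_samples, num_steps, False


loader = [([1.0, 2.0], [2.0, 4.0]), ([3.0], [6.0])]
print(toy_local_train([0.0], loader, epochs=2, steps=3))
```

For the two-batch loader above, two epochs give 4 steps over 6 samples, and the 3 extra steps (batches 1, 2, 1) add 5 more samples. The call therefore returns `(11, 7, False)`. Raising `lr` to a large value such as `100.0` makes the loss overflow to infinity, and the loop returns early with the divergence flag set.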