"""
Step Closures
-------------
"""
from torch.nn.utils import clip_grad_norm_
def default_step_closure(
    x,
    y,
    model,
    criterion,
    optimizer,
    scores,
    max_grad_norm=1000,
    device="cpu",
    transform_grads=None,
    transform_y=None,
    **kwargs,
):
"""one step of local training including:
* prepare mini batch of the data
* forward pass
* loss calculation
* backward pass
* transfor and modify the gradients
* take optimization step
* evaluate scores on the training mini-batch batch.
Args:
x (Tensor): inputs
y (Tensor): labels
model (Module): model
criterion (Callable): loss criterion
optimizer (Optimizer): optimizer chosen and instanciated from classes under
``torch.optim``.
scores: Dict[str, Score]: dictionary of form str: Score to evaluate at the end
of the closure.
max_grad_norm (int, optional): to clip the norm of the gradients.
Defaults to 1000.
device (str, optional): device to load the data into
("cpu", "cuda", or device ordinal number). This must be the same device as
the one model parameters are loaded into. Defaults to "cpu".
transform_grads (Callable, optional): A function the takes the model and
modified the gradients of the parameters. Defaults to None.
transform_y (Callable, optional): a function that takes raw labels and modifies
them. Defaults to None.
Returns:
Tensor: loss value obtained from the forward pass.
"""
    # prepare the mini-batch: apply the optional label transform, flatten the
    # labels, and move both tensors to the target device
    if transform_y is not None:
        y = transform_y(y)
    x = x.to(device)
    y = y.reshape(-1).long()
    y = y.to(device)
    # forward pass and loss calculation
    model.train()
    outputs = model(x)
    loss = criterion(outputs, y)
    # skip the update entirely if the loss is not finite
    if loss.isnan() or loss.isinf():
        return loss
    # backward pass
    loss.backward()
    if transform_grads is not None:
        transform_grads(model)
    # clip the norm of the gradients
    clip_grad_norm_(parameters=model.parameters(), max_norm=max_grad_norm)
    # take the optimization step and reset the gradients
    optimizer.step()
    optimizer.zero_grad()
    # evaluate the scores on the training mini-batch
    if scores is not None:
        for score in scores.values():
            score(outputs, y)
    return loss
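

# ---------------------------------------------------------------------------
# Illustrative sketch, not part of fedsim: an example of a callable that
# could be passed as ``transform_grads``. The hook receives the model as its
# only argument; this one simply halves every parameter gradient in place.
# The name and the scaling factor are arbitrary choices for demonstration.
# ---------------------------------------------------------------------------
def halve_grads_example(model):
    for param in model.parameters():
        if param.grad is not None:
            param.grad.mul_(0.5)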
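

# ---------------------------------------------------------------------------
# Illustrative usage sketch, not part of fedsim: how a local trainer might
# take one training step with ``default_step_closure``. The toy model,
# optimizer, data, and hyperparameters below are hypothetical stand-ins
# chosen only to make the call signature concrete.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch

    model = torch.nn.Linear(10, 3)  # toy classifier: 10 features, 3 classes
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    x = torch.randn(8, 10)  # a mini-batch of 8 samples
    y = torch.randint(0, 3, (8,))  # integer class labels

    loss = default_step_closure(
        x,
        y,
        model,
        criterion,
        optimizer,
        scores=None,  # no training-time metrics in this sketch
    )
    print(f"loss after one step: {loss.item():.4f}")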