Source code for fedsim.local.training.training
"""
Local Training
--------------
Training loop for a local client
"""
import inspect
from .step_closures import default_step_closure
def local_train(
model,
train_data_loader,
epochs,
steps,
criterion,
optimizer,
lr_scheduler=None,
device="cpu",
step_closure=default_step_closure,
scores=None,
max_grad_norm=1000,
**step_ctx,
):
    """Local training loop for a single client.

    Args:
        model (Module): model to use for getting the predictions.
        train_data_loader (Iterable): training data loader.
        epochs (int): number of local epochs.
        steps (int): number of additional optimization steps to take after
            the final epoch.
        criterion (Callable): loss criterion.
        optimizer (Optimizer): a torch optimizer.
        lr_scheduler (Any, optional): a torch learning-rate scheduler.
            Defaults to None.
        device (str, optional): device to load the data onto ("cpu", "cuda",
            or a device ordinal number). This must be the same device the
            model parameters are loaded onto. Defaults to "cpu".
        step_closure (Callable, optional): step closure for an optimization
            step. Defaults to default_step_closure.
        scores (Dict[str, Score], optional): a dictionary of str: Score.
            Defaults to None.
        max_grad_norm (int, optional): maximum norm used to clip the
            gradients. Defaults to 1000.

    Returns:
        Tuple[int, int, bool]: number of training samples, number of
        optimization steps taken, and a flag indicating divergence.
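
    Example:
        A minimal usage sketch (illustrative only, not from fedsim's docs;
        it assumes ``default_step_closure`` accepts a plain callable loss
        such as ``CrossEntropyLoss``):

        .. code-block:: python

            import torch
            from torch.utils.data import DataLoader, TensorDataset

            model = torch.nn.Linear(10, 2)
            dataset = TensorDataset(
                torch.randn(64, 10), torch.randint(0, 2, (64,))
            )
            loader = DataLoader(dataset, batch_size=8)
            optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

            n_samples, n_steps, diverged = local_train(
                model,
                loader,
                epochs=1,
                steps=0,
                criterion=torch.nn.CrossEntropyLoss(),
                optimizer=optimizer,
            )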
"""
if steps > 0:
        # the extra steps run in one additional pass over the data that is
        # cut short inside the mini-batch loop below
epochs += 1
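    # e.g. with epochs=2 and steps=5, the loop below makes two full passes
    # over the data plus at most five extra optimization steps in a third pass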
# instantiate control variables
num_steps = 0
diverged = False
all_loss = 0
num_train_samples = 0
if train_data_loader is not None:
# iteration over epochs
        for epoch in range(epochs):
if diverged:
break
# iteration over mini-batches
epoch_step_cnt = 0
            for x, y in train_data_loader:
                # in the extra pass appended when ``steps > 0``, stop once
                # the requested number of extra optimization steps has run
                if steps > 0 and epoch == epochs - 1 and epoch_step_cnt >= steps:
                    break
                # the step closure is expected to move the mini-batch to the
                # device, compute the local objective's loss, and take one
                # optimization step
loss = step_closure(
x,
y,
model,
criterion,
optimizer,
scores,
max_grad_norm,
device=device,
**step_ctx,
)
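                # abort local training as soon as the loss becomes NaN/Inf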
if loss.isnan() or loss.isinf():
del loss
diverged = True
break
# update control variables
epoch_step_cnt += 1
num_steps += 1
num_train_samples += y.shape[0]
all_loss += loss.item()
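            # advance the learning-rate scheduler once per local epoch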
if lr_scheduler is not None:
step_args = inspect.signature(lr_scheduler.step).parameters
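                # metric-aware schedulers (e.g. ReduceLROnPlateau) take a
                # ``metrics`` value, so pass them the configured trigger score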
if "metrics" in step_args:
                    comb_scores = {**(scores or {}), criterion.get_name(): criterion}
trigger_metric = lr_scheduler.trigger_metric
if trigger_metric not in comb_scores:
                        raise Exception(
                            f"{trigger_metric} not in local scores. "
                            f"Possible options are {list(comb_scores.keys())}"
                        )
lr_scheduler.step(comb_scores[trigger_metric].get_score())
else:
lr_scheduler.step()
return (
num_train_samples,
num_steps,
diverged,
)
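

# What follows is a minimal sketch of a custom ``step_closure`` compatible
# with the call made inside ``local_train``. It is illustrative only: the
# name ``simple_step_closure`` is hypothetical, and ``default_step_closure``
# may behave differently (e.g. it also uses ``scores``, which this sketch
# ignores).
import torch


def simple_step_closure(
    x, y, model, criterion, optimizer, scores, max_grad_norm, device="cpu", **ctx
):
    # move the mini-batch to the same device as the model parameters
    x, y = x.to(device), y.to(device)
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    # clip gradients before stepping, mirroring the ``max_grad_norm`` argument
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    optimizer.step()
    # return the detached loss so the caller can check for divergence
    return loss.detach()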