# Source code for fedsim.local.training.inference
"""
Local Inference
---------------
Inference for local client
"""
import torch
def local_inference(
    model,
    data_loader,
    scores,
    device="cpu",
    transform_y=None,
):
    """Evaluate a model on an inference/test set, feeding each batch's
    predictions to the provided score objects.

    The model is switched to eval mode for the duration of the loop and its
    original training/eval mode is restored afterwards, even if an exception
    is raised mid-evaluation.

    Args:
        model (Module): model to get the predictions from.
        data_loader (Iterable): inference data loader yielding ``(X, y)``
            batches.
        scores (Dict[str, Score]): scores to evaluate; each is called as
            ``score(outputs, y)`` per batch.
        device (str, optional): device to load the data into
            ("cpu", "cuda", or device ordinal number). This must be the same
            device as the one model parameters are loaded into. Defaults to
            "cpu".
        transform_y (Callable, optional): a function that takes raw labels
            and modifies them. Defaults to None.

    Returns:
        int: number of samples the evaluation is done for.
    """
    num_samples = 0
    # Remember the caller's mode so evaluation is side-effect free.
    model_is_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for (X, y) in data_loader:
                if transform_y is not None:
                    y = transform_y(y)
                # Flatten labels to a 1-D long tensor, as expected by
                # classification-style score functions.
                y = y.reshape(-1).long()
                y = y.to(device)
                X = X.to(device)
                outputs = model(X)
                num_samples += len(y)
                for score in scores.values():
                    score(outputs, y)
                # Free batch outputs eagerly to keep peak memory low.
                del outputs
    finally:
        # Restore the original mode even if a score or forward pass raised.
        model.train(model_is_training)
    return num_samples