Trainer
- class kagu.trainers.trainer.TrainerArguments(tb_dir: str, files_dir: str, backbone: str, predictor: str, snippet: int, step: int)
Arguments for the Trainer class.
- tb_dir: str
The path to the TensorBoard logging directory.
- files_dir: str
The path to the file logging directory.
- backbone: str
The backbone architecture.
- predictor: str
The predictor type.
- snippet: int
The number of frames processed together as one snippet.
- step: int
The stride, in frames, between the starts of consecutive snippets. Set it equal to snippet for non-overlapping snippets; a smaller value makes consecutive snippets overlap.
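The snippet and step fields together define a sliding window over the frame sequence. The sketch below is a minimal illustration, not the kagu implementation: `TrainerArgumentsSketch` and `snippet_windows` are hypothetical names, and it assumes that trailing frames which do not fill a complete final snippet are dropped (the real trainer may pad or keep them instead).

```python
from dataclasses import dataclass
from typing import List


@dataclass
class TrainerArgumentsSketch:
    """Hypothetical stand-in mirroring the documented TrainerArguments fields."""
    tb_dir: str      # TensorBoard logging directory
    files_dir: str   # file logging directory
    backbone: str    # backbone architecture name
    predictor: str   # predictor type
    snippet: int     # number of frames per snippet
    step: int        # stride between snippet starts, in frames


def snippet_windows(num_frames: int, snippet: int, step: int) -> List[range]:
    """Return the frame index range covered by each snippet.

    Assumption: only full-length snippets are kept; frames at the end that
    do not fill a whole snippet are dropped.
    """
    if snippet <= 0 or step <= 0:
        raise ValueError("snippet and step must be positive")
    # Start indices advance by `step`; the last start still fits a full snippet.
    return [range(start, start + snippet)
            for start in range(0, num_frames - snippet + 1, step)]


args = TrainerArgumentsSketch("runs/tb", "runs/files", "my_backbone",
                              "my_predictor", snippet=4, step=4)
windows = snippet_windows(10, args.snippet, args.step)
```

With snippet=4 and step=4, a 10-frame clip yields two non-overlapping snippets starting at frames 0 and 4; with step=2 the snippets overlap and start at frames 0, 2, 4 and 6.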
- class kagu.trainers.trainer.Trainer(args: TrainerArguments, loader, backbone, model, loss, optimizer)
The trainer class. It handles the training loop, logging, and iteration over the dataset.
- Parameters:
args (TrainerArguments) – The configuration arguments for the trainer.
loader (KaguDataset) – The dataset to iterate over.
backbone (torch.nn.Module) – The backbone encoder model.
model (torch.nn.Module) – The predictor model.
loss (torch.nn.Module) – The loss module used to compute the prediction loss.
optimizer (torch.optim.Optimizer) – The Adam optimizer used to update the parameters from their gradients.
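As a rough picture of what one training pass over these components might look like, here is a framework-agnostic sketch. It is not the kagu source: the function `train_one_epoch`, the `log` callback, and the assumption that each batch is an `(inputs, targets)` pair are all hypothetical. It relies only on PyTorch-style conventions (`optimizer.zero_grad()`, `loss.backward()`, `optimizer.step()`, `.item()`), so real `torch.nn.Module` and `torch.optim.Adam` objects could in principle be passed in.

```python
from typing import Any, Callable, Iterable, List, Tuple


def train_one_epoch(
    loader: Iterable[Tuple[Any, Any]],
    backbone: Callable[[Any], Any],
    model: Callable[[Any], Any],
    loss: Callable[[Any, Any], Any],
    optimizer: Any,
    log: Callable[[int, float], None] = lambda i, v: None,
) -> List[float]:
    """Hypothetical sketch of one pass over the data.

    Assumes each batch is an (inputs, targets) pair, that loss(...) returns a
    scalar object with .backward() and .item(), and that the optimizer
    exposes zero_grad() and step(), as in PyTorch.
    """
    history: List[float] = []
    for i, (inputs, targets) in enumerate(loader):
        features = backbone(inputs)             # encode the inputs with the backbone
        prediction = model(features)            # predictor applied to the features
        loss_value = loss(prediction, targets)  # prediction loss for this batch
        optimizer.zero_grad()                   # clear gradients from the previous step
        loss_value.backward()                   # compute gradients of the loss
        optimizer.step()                        # update the parameters held by the optimizer
        history.append(loss_value.item())
        log(i, history[-1])                     # e.g. forward to a TensorBoard writer (hypothetical)
    return history
```

Keeping the backward pass and optimizer step inside the loop means only the parameters registered with the optimizer are updated, so the same loop works whether or not the backbone is trained jointly.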