Trainer

class kagu.trainers.trainer.TrainerArguments(tb_dir: str, files_dir: str, backbone: str, predictor: str, snippet: int, step: int)

Arguments for the Trainer class.

tb_dir: str

The path to the tensorboard logging directory

files_dir: str

The path to the files logging directory

backbone: str

The backbone architecture

predictor: str

The predictor type

snippet: int

Number of frames to process

step: int

The stride, in frames, by which processing advances through the frames. Set it equal to snippet when snippets should not overlap.
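The relationship between snippet and step can be sketched in plain Python. This is only an illustration: `TrainerArgumentsSketch` is a hypothetical stand-in that mirrors the documented fields (it is not kagu's class), `snippet_windows` is a hypothetical helper, the `backbone` and `predictor` strings are placeholders, and dropping a trailing partial window is an assumption rather than documented behavior.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class TrainerArgumentsSketch:
    """Hypothetical stand-in mirroring the documented fields; not kagu's class."""
    tb_dir: str
    files_dir: str
    backbone: str
    predictor: str
    snippet: int
    step: int


def snippet_windows(num_frames: int, snippet: int, step: int) -> List[Tuple[int, int]]:
    """Return half-open (start, end) frame ranges of length `snippet`, spaced `step` apart.

    Assumption: a trailing window shorter than `snippet` is dropped.
    """
    if snippet <= 0 or step <= 0:
        raise ValueError("snippet and step must be positive")
    return [(s, s + snippet) for s in range(0, num_frames - snippet + 1, step)]


# Placeholder values; valid backbone/predictor names are not listed here.
args = TrainerArgumentsSketch(
    tb_dir="runs/tb",
    files_dir="runs/files",
    backbone="resnet",
    predictor="lstm",
    snippet=4,
    step=4,
)

# step == snippet: consecutive windows share no frames.
non_overlapping = snippet_windows(12, args.snippet, args.step)

# step < snippet: each window overlaps the previous one by snippet - step frames.
overlapping = snippet_windows(12, snippet=4, step=2)
```

With 12 frames, the first call yields three disjoint windows, while the second yields five windows that each share two frames with their neighbor.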

class kagu.trainers.trainer.Trainer(args: TrainerArguments, loader, backbone, model, loss, optimizer)

The trainer class: takes care of the training loop, logging and iterating over the dataset.

Parameters:
  • args (TrainerArguments) – The arguments passed to Trainer class

  • loader (KaguDataset) – The dataset to iterate over.

  • backbone (torch.nn.Module) – The backbone encoder model.

  • loss (torch.nn.Module) – The loss module for calculating the prediction loss.

  • optimizer (torch.optim.Optimizer) – The Adam optimizer used for stepping gradients.
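As a rough sketch of the kind of objects these parameters describe, assuming PyTorch is installed: the layer shapes, learning rate, and the choice to optimize both the backbone and the model are all illustrative assumptions, and the single update step only shows how the pieces interact; it is not the Trainer's actual loop. Constructing the Trainer itself is omitted because its behavior is not described here.

```python
import torch
from torch import nn

# Toy stand-ins for the documented roles; real architectures will differ.
backbone = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 16))  # encoder
model = nn.Linear(16, 16)                                         # prediction head
loss = nn.MSELoss()                                               # a loss is an nn.Module

# torch.optim.Adam is a torch.optim.Optimizer, not an nn.Module.
# Assumption: both backbone and model parameters are trained.
params = list(backbone.parameters()) + list(model.parameters())
optimizer = torch.optim.Adam(params, lr=1e-3)

# One illustrative update on random data.
frames = torch.randn(2, 3, 8, 8)
target = torch.randn(2, 16)

optimizer.zero_grad()
prediction = model(backbone(frames))
loss_value = loss(prediction, target)
loss_value.backward()
optimizer.step()
```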

train()

The training function. Call it once after instantiating the trainer.