Models#

Backbone Encoder#

from kagu.models import BackboneArguments, Backbone

backbone_args = BackboneArguments(
    backbone='inception',
    backbone_pretrained=True,
    backbone_frozen=True,
)
backbone = Backbone(args=backbone_args)

class kagu.models.model.BackboneArguments(backbone: str, backbone_pretrained: bool, backbone_frozen: bool)[source]#

Arguments for the Backbone

backbone: str#

The backbone architecture

backbone_pretrained: bool#

Use pretrained backbone weights

backbone_frozen: bool#

Freeze the backbone weights
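
The three fields above form a small configuration container. As a rough illustration only, the sketch below mirrors them in a standard-library dataclass. `BackboneArgumentsSketch` is a hypothetical stand-in, not the kagu class, and the type checks in `__post_init__` are an assumption that the source does not document.

```python
from dataclasses import dataclass


@dataclass
class BackboneArgumentsSketch:
    """Hypothetical stand-in; NOT the kagu class, it only mirrors the documented fields."""

    backbone: str              # the backbone architecture, e.g. 'inception'
    backbone_pretrained: bool  # use pretrained backbone weights
    backbone_frozen: bool      # freeze the backbone weights

    def __post_init__(self):
        # Assumed validation (not described in the source): fail early on wrong types.
        if not isinstance(self.backbone, str):
            raise TypeError("backbone must be a str")
        for name in ("backbone_pretrained", "backbone_frozen"):
            if not isinstance(getattr(self, name), bool):
                raise TypeError(f"{name} must be a bool")


args = BackboneArgumentsSketch(
    backbone="inception",
    backbone_pretrained=True,
    backbone_frozen=True,
)
```

Passing every field by keyword, as in the usage example at the top of this section, keeps the two boolean flags from being swapped by accident.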

class kagu.models.model.Backbone(args: BackboneArguments)[source]#

The backbone encoder model.

This model receives input images and returns a grid of feature vectors.

Parameters:

args (BackboneArguments) – The parameters used for the Backbone Model

forward(x)[source]#

The forward propagation function. It takes a snippet of input images and returns a grid of feature vectors for each image.

Parameters:

x (torch.Tensor) – tensor of shape [Snippet, 3, H, W]

Returns:

  • (torch.Tensor): feature vectors of shape [Snippet, h*w, 2048], where h and w are the height and width of the output grid
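
The shape contract above can be written down as a small helper for sanity checks. This is a minimal sketch using only the standard library. The function names are hypothetical, and the 8x8 grid in the example is an assumed value, since the source does not state the grid size produced by any particular backbone.

```python
def expected_feature_shape(snippet, grid_h, grid_w, channels=2048):
    """Shape of the flattened feature grid: [snippet, h*w, channels]."""
    if min(snippet, grid_h, grid_w, channels) <= 0:
        raise ValueError("all dimensions must be positive")
    return (snippet, grid_h * grid_w, channels)


def check_feature_shape(shape, snippet, grid_h, grid_w, channels=2048):
    """Return True if `shape` matches the documented output layout."""
    return tuple(shape) == expected_feature_shape(snippet, grid_h, grid_w, channels)


# Example: 8 frames whose (assumed) 8x8 grid is flattened to 64 positions.
shape = expected_feature_shape(snippet=8, grid_h=8, grid_w=8)
print(shape)  # (8, 64, 2048)
```

In practice the first argument of `check_feature_shape` would be the `.shape` of the tensor returned by `forward`.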


Kagu Model#

from kagu.models import KaguModelArguments, KaguModel

kagu_args = KaguModelArguments(
    predictor='lstm',
    dropout=0.4,
    teacher=True,
    step=8,
)
kagu_model = KaguModel(args=kagu_args)

class kagu.models.model.KaguModelArguments(predictor: str, dropout: float, teacher: bool, step: int)[source]#

Arguments for the Kagu model

predictor: str#

The predictor type

dropout: float#

The dropout rate

teacher: bool#

Enable teacher forcing

step: int#

The stride by which the dataset is streamed
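
To make the role of a stride concrete, the sketch below enumerates frame windows over a video. It assumes that `step` is the number of frames by which each consecutive snippet advances. That reading is an interpretation of the description above, not something the source confirms. `snippet_windows` and `snippet_len` are hypothetical names.

```python
def snippet_windows(num_frames, snippet_len, step):
    """Yield (start, end) frame index ranges, advancing `step` frames each time."""
    if snippet_len <= 0 or step <= 0:
        raise ValueError("snippet_len and step must be positive")
    # Stop once a full snippet no longer fits in the remaining frames.
    for start in range(0, num_frames - snippet_len + 1, step):
        yield (start, start + snippet_len)


# With step equal to the snippet length, the windows do not overlap.
windows = list(snippet_windows(num_frames=32, snippet_len=8, step=8))
print(windows)  # [(0, 8), (8, 16), (16, 24), (24, 32)]
```

Under this assumption, a `step` smaller than the snippet length would produce overlapping windows, and a larger one would skip frames.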

class kagu.models.model.KaguModel(args: KaguModelArguments)[source]#

The implementation of the Kagu model used for training.

Parameters:

args (KaguModelArguments) – The arguments passed to the Kagu Model

forward(x, hidden, p)[source]#

The forward propagation function of the main Kagu model.

Parameters:
  • x (torch.Tensor) – tensor of shape [snippet, h*w, 2048]

  • hidden (tuple(torch.Tensor, torch.Tensor)) – The hidden state. If the predictor is an LSTM, this is a tuple of two tensors, each of shape [h*w, 2048]

  • p (torch.Tensor) – The previous prediction, used for teacher forcing. Shape [h*w, 2048]

Returns:

  • (torch.Tensor): Prediction of the model, of shape [snippet, h*w, 2048]

  • (torch.Tensor): Attention grid of shape [8, h*w, 1]

  • (tuple(torch.Tensor, torch.Tensor)): hidden states of the LSTM.

  • (torch.Tensor): P_out for teacher forcing