Models
Backbone Encoder
from kagu.models import BackboneArguments, Backbone

# Configure a pretrained, frozen Inception backbone.
backbone_args = BackboneArguments(
    backbone='inception',
    backbone_pretrained=True,
    backbone_frozen=True,
)
backbone = Backbone(args=backbone_args)
- class kagu.models.model.BackboneArguments(backbone: str, backbone_pretrained: bool, backbone_frozen: bool)
Arguments for the Backbone model.
- backbone: str
The name of the backbone architecture (for example, 'inception').
- backbone_pretrained: bool
Whether to load pretrained weights for the backbone.
- backbone_frozen: bool
Whether to freeze the backbone weights so they are not updated during training.
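The signature above suggests that BackboneArguments is a plain dataclass. Under that assumption, a configuration dictionary (for example, one loaded from a JSON or YAML file) can be unpacked into it with keyword arguments. The sketch below uses a hypothetical stand-in, BackboneArgumentsSketch, that only mirrors the documented fields. It is not the library class.

```python
from dataclasses import asdict, dataclass


# Hypothetical stand-in mirroring the documented fields of
# kagu.models.model.BackboneArguments; it is not the library class.
@dataclass
class BackboneArgumentsSketch:
    backbone: str               # name of the backbone architecture
    backbone_pretrained: bool   # load pretrained weights
    backbone_frozen: bool       # keep the backbone weights fixed during training


# A configuration as it might be loaded from a JSON or YAML file.
config = {
    "backbone": "inception",
    "backbone_pretrained": True,
    "backbone_frozen": True,
}

args = BackboneArgumentsSketch(**config)
print(args)
print(asdict(args) == config)  # True: the fields round-trip
```

Keeping the arguments in a dataclass means that an unknown or missing key fails immediately with a TypeError instead of surfacing later inside the model.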
- class kagu.models.model.Backbone(args: BackboneArguments)
The backbone encoder model.
This model receives input images and returns a grid of feature vectors.
- Parameters:
args (BackboneArguments) – The arguments used to configure the Backbone model.
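The exact layout of the returned grid is not documented here. Convolutional backbones typically produce a channel-first feature map of shape [C, h, w]. Assuming this one does, the sketch below shows how such a map lines up with the [h*w, 2048] layout used by KaguModel.forward: each spatial cell becomes one feature vector of length C. The helper flatten_grid is hypothetical and uses plain lists instead of tensors.

```python
from typing import List

# A feature map stored channel-first as [channels][height][width].
FeatureMap = List[List[List[float]]]


def flatten_grid(feature_map: FeatureMap) -> List[List[float]]:
    """Turn a [C, h, w] feature map into h*w feature vectors of length C.

    Cells are ordered row by row, so cell (row, col) becomes index row * w + col.
    """
    channels = len(feature_map)
    height = len(feature_map[0])
    width = len(feature_map[0][0])
    return [
        [feature_map[c][row][col] for c in range(channels)]
        for row in range(height)
        for col in range(width)
    ]


# Toy map with C=3 channels on a 2x2 grid (KaguModel.forward documents C=2048).
toy = [
    [[1.0, 2.0], [3.0, 4.0]],     # channel 0
    [[5.0, 6.0], [7.0, 8.0]],     # channel 1
    [[9.0, 10.0], [11.0, 12.0]],  # channel 2
]
vectors = flatten_grid(toy)
print(len(vectors), len(vectors[0]))  # 4 3
print(vectors[1])  # [2.0, 6.0, 10.0]
```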
Kagu Model
from kagu.models import KaguModelArguments, KaguModel

kagu_args = KaguModelArguments(
    predictor='lstm',  # predictor type
    dropout=0.4,       # dropout rate
    teacher=True,      # enable teacher forcing
    step=8,            # stride by which the dataset is streamed
)
kagu_model = KaguModel(args=kagu_args)
- class kagu.models.model.KaguModelArguments(predictor: str, dropout: float, teacher: bool, step: int)
Arguments for the Kagu model.
- predictor: str
The predictor type (for example, 'lstm').
- dropout: float
The dropout rate.
- teacher: bool
Whether to enable teacher forcing.
- step: int
The stride by which the dataset is streamed.
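The streaming logic behind step is not documented here. One common reading of a stride parameter is a sliding window: a new window of frames starts every step frames, and when the window length equals step the windows do not overlap. The helper stream_snippets and its snippet_len parameter are hypothetical. They only illustrate that reading and do not reproduce the library's code.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def stream_snippets(frames: Sequence[T], snippet_len: int, step: int) -> List[List[T]]:
    """Cut a frame sequence into windows of snippet_len frames, starting every step frames.

    Trailing frames that do not fill a complete window are dropped.
    """
    if snippet_len <= 0 or step <= 0:
        raise ValueError("snippet_len and step must be positive")
    return [
        list(frames[start:start + snippet_len])
        for start in range(0, len(frames) - snippet_len + 1, step)
    ]


frames = list(range(20))  # frame indices 0..19
snippets = stream_snippets(frames, snippet_len=8, step=8)
print(len(snippets))   # 2
print(snippets[1][0])  # 8
```

With step=8 and 8-frame windows, frames 16 to 19 never fill a full window and are skipped. A smaller step than the window length would make consecutive snippets overlap.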
- class kagu.models.model.KaguModel(args: KaguModelArguments)
Implementation of the Kagu model for training.
- Parameters:
args (KaguModelArguments) – The arguments used to configure the Kagu model.
- forward(x, hidden, p)
The forward propagation function of the main Kagu model.
- Parameters:
x (torch.Tensor) – Input tensor of shape [snippet, h*w, 2048].
hidden (tuple(torch.Tensor, torch.Tensor)) – The hidden state. When the predictor is an LSTM, this is a tuple of two tensors, each of shape [h*w, 2048].
p (torch.Tensor) – The previous prediction, used for teacher forcing. Shape [h*w, 2048].
- Returns:
(torch.Tensor): Prediction of the model, shape [snippet, h*w, 2048].
(torch.Tensor): Attention grid, shape [8, h*w, 1].
(tuple(torch.Tensor, torch.Tensor)): Hidden states of the LSTM.
(torch.Tensor): P_out, the output used for teacher forcing.
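Because forward both takes and returns a hidden state and a teacher-forcing tensor, a streaming loop presumably threads them from one call to the next. The sketch below assumes that the returned hidden states and P_out are passed back as hidden and p on the following call, and that the first call starts from empty state. It uses a hypothetical toy_forward with the same four-part return structure and scalar arithmetic in place of tensors. It does not reproduce the model's actual computation.

```python
from typing import List, Optional, Tuple

Hidden = Tuple[float, float]


def toy_forward(x: float, hidden: Optional[Hidden], p: Optional[float]):
    """Hypothetical stand-in with the same return structure as KaguModel.forward.

    Returns (prediction, attention, hidden, p_out) using scalars instead of tensors.
    """
    _h, c = hidden if hidden is not None else (0.0, 0.0)
    prev = p if p is not None else 0.0
    c = 0.5 * c + x      # toy cell-state update
    h = c + 0.1 * prev   # toy hidden state mixes in the previous output
    prediction = h
    attention = 1.0      # placeholder for the attention grid
    return prediction, attention, (h, c), prediction


def run_stream(snippets: List[float]) -> List[float]:
    """Call toy_forward on each snippet, feeding hidden and p_out back in."""
    hidden: Optional[Hidden] = None
    p: Optional[float] = None
    predictions = []
    for x in snippets:
        pred, _attn, hidden, p = toy_forward(x, hidden, p)
        predictions.append(pred)
    return predictions


print(run_stream([1.0, 2.0, 3.0]))
```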