Source code for streamer.models.model

import torch, os
import torch.nn as nn
import torch.nn.functional as F
from streamer.models.layer import StreamerLayer, StreamerLayerArguments
from dataclasses import dataclass, asdict

@dataclass
class StreamerModelArguments:
    r"""Arguments (hyper-parameters) of :py:class:`StreamerModel`."""

    log_base_every: int = 1000
    r"""Tensorboard log of the base layer every 'log_base_every'"""

    main: bool = True
    r"""Main process/gpu or not"""

    max_layers: int = 3
    r"""The maximum number of layers to stack"""

    feature_dim: int = 1024
    r"""Feature dimension of the model embeddings"""

    evolve_every: int = 50000
    r"""Create/stack a new layer every 'evolve_every' """

    buffer_size: int = 10
    r"""Maximum input buffer size to be used"""

    force_fixed_buffer: bool = False
    r"""Force the buffer to be fixed (not replacing inputs) by triggering a boundary when buffer is full"""

    loss_threshold: float = 0.25
    r"""Loss threshold value. Not used in average demarcation mode"""

    lr: float = 1e-4
    r"""Learning rate to be used in all modules"""

    init_layers: int = 1
    r"""How many layers to initialize before training"""

    init_ckpt: str = ''
    r"""the path of the pretrained weights, if any"""

    ckpt_dir: str = ''
    r"""the path to save weights"""

    snippet_size: float = 0.5
    r"""Snippet size of input video (seconds/image). Typically 0.5 seconds per image"""

    demarcation_mode: str = 'average'
    r"""Demarcation mode used to detect boundaries"""

    distance_mode: str = 'distance'
    r"""Distance mode for loss calculation"""

    force_base_dist: bool = True
    r"""Force the lowest layer to use MSE instead of Cosine Similarity"""

    window_size: int = 50
    r"""Window size for average demarcation mode"""

    modifier_type: str = 'multiply'
    r"""Modifier type to apply to average demarcation mode ['multiply', 'add']"""

    modifier: float = 1.0
    r"""Modifier to apply to average demarcation mode"""

    @classmethod
    def _from_namespace(cls, args, lenient=False):
        r"""
        Copy one attribute of ``args`` per dataclass field into a new instance.

        Iterating over the declared fields keeps the alternate constructors in sync
        with the dataclass definition. Attributes of ``args`` that are not fields
        are ignored.

        :param args: object exposing one attribute per field (e.g., ``argparse.Namespace``)
        :param bool lenient: if True, fields absent from ``args`` keep their declared
            default instead of raising :py:class:`AttributeError`
        :returns: (*StreamerModelArguments*): the populated arguments
        """
        from dataclasses import fields

        kwargs = {}
        for field_ in fields(cls):
            if lenient and not hasattr(args, field_.name):
                continue  # keep the dataclass default
            kwargs[field_.name] = getattr(args, field_.name)
        return cls(**kwargs)

    @classmethod
    def from_args(cls, args):
        r"""
        Build the model arguments from parsed command-line arguments.

        Every field is read from the attribute of the same name on ``args``;
        a missing attribute raises :py:class:`AttributeError`.

        :param args: namespace holding one attribute per field
        :returns: (*StreamerModelArguments*): the populated arguments
        """
        return cls._from_namespace(args, lenient=False)

    @classmethod
    def from_ckpt_args(cls, args):
        r"""
        Build the model arguments from arguments restored from a checkpoint.

        Fields missing from ``args`` (e.g., from a checkpoint written before the
        field existed) keep their default value. Extra attributes such as
        ``fp_up``/``fp_down`` are ignored because they are not fields of this class.

        :param args: namespace holding (some of) the fields
        :returns: (*StreamerModelArguments*): the populated arguments
        """
        return cls._from_namespace(args, lenient=True)
class StreamerModel(nn.Module):
    r"""
    The implementation of the STREAMER model for training.

    This class initializes the first layer(s) of streamer, saves/loads weights, etc.

    :param StreamerModelArguments args: The arguments passed to Streamer Model
    :param Logger logger: The tensorboard logger class
    :param torch.nn.Module encoder: The encoder model (e.g., :py:class:`~streamer.models.networks.CNNEncoder`)
    :param torch.nn.Module decoder: The decoder model (e.g., :py:class:`~streamer.models.networks.CNNDecoder`)
    """

    def __init__(self, args: "StreamerModelArguments", logger=None, encoder=None, decoder=None):
        super().__init__()
        self.args = args
        self.logger = logger
        self.streamer = None
        self.models_saved = 0

        # counters
        self.base_counter = 0

        self.__initialize(args.init_ckpt, args.init_layers, encoder=encoder, decoder=decoder)

    def init_layer(self, count=1, encoder=None, decoder=None):
        r"""
        Initializes the Streamer layer(s) and passes the encoder/decoder to the first layer.

        :param int count: How many layers to create
        :param torch.nn.Module encoder: The encoder model (e.g., :py:class:`~streamer.models.networks.CNNEncoder`)
        :param torch.nn.Module decoder: The decoder model (e.g., :py:class:`~streamer.models.networks.CNNDecoder`)
        """
        # Expose getReps through the arguments before converting them, so the
        # layer arguments carry the callback (presumably used by the layers).
        self.args.reps_fn = self.getReps
        layer_args = StreamerLayerArguments.from_model_args(self.args)

        self.streamer = StreamerLayer(
            args=layer_args,
            layer_num=0,
            init_count=count,
            encoder=encoder,
            decoder=decoder,
            logger=self.logger,
        )

    def __initialize(self, ckpt='', count=1, encoder=None, decoder=None):
        r"""
        Initializes the Streamer layer(s) with checkpoint if available.

        :param str ckpt: pretrained weights location ('' or None means no checkpoint)
        :param int count: How many layers to create (ignored when a checkpoint is loaded;
            the checkpoint's ``num_layers`` is used instead)
        :param torch.nn.Module encoder: The encoder model (e.g., :py:class:`~streamer.models.networks.CNNEncoder`)
        :param torch.nn.Module decoder: The decoder model (e.g., :py:class:`~streamer.models.networks.CNNDecoder`)
        """
        if not ckpt:
            self.init_layer(count=count, encoder=encoder, decoder=decoder)
            return

        state = torch.load(ckpt, map_location='cpu')
        self.init_layer(count=state['num_layers'], encoder=encoder, decoder=decoder)
        self.streamer.load_state_dict(state['weights'])

    def forward(self, x):
        r"""
        Forward propagation function that calls the :py:meth:`~streamer.models.layer.StreamerLayer.forward`
        function of the first :py:class:`~streamer.models.layer.StreamerLayer`.

        :param torch.Tensor x: the input image [1, 3, H, W]
        """
        # log base signal
        if self.logger is not None:
            self.logger.model(self.streamer.context, x, self.streamer.attn_img, self.streamer.attns)

        # main forward
        self.streamer(x, base_counter=self.base_counter)
        self.base_counter += 1

    def getReps(self, layer_num):
        r"""
        Aggregates the representations from all the layers.

        Walks up from the first layer through ``above`` links and stops at the first
        layer that has no ``above`` layer, or whose ``above`` layer has no
        representation yet.

        :param int layer_num: the index of the calling layer (currently unused)
        :returns: (*torch.Tensor*): representations of the visited layers concatenated
            along dim 0 (presumably [L, feature_dim] if each is [1, feature_dim] — confirm)
        """
        layer = self.streamer
        reps = [layer.representation]
        # `is not None` avoids evaluating tensor comparisons with None.
        while layer.above is not None and layer.above.representation is not None:
            layer = layer.above
            reps.append(layer.representation)
        return torch.cat(reps, dim=0)

    def extract_rep(self):
        r"""
        Extract representation function used for logging json hierarchy.
        Calls recursive function :py:meth:`~streamer.models.layer.StreamerLayer.extract_rep`
        on the :py:class:`~streamer.models.layer.StreamerLayer` class.

        Every layer above the base whose boundary list does not already end with the
        last base boundary gets that boundary appended, closing its open segment.

        :returns: (*dict*): hierarchy represented as boundaries of every layer
        """
        if self.streamer is None:
            return {}

        hierarchy = dict(boundaries={})
        self.streamer.extract_rep(hierarchy)

        # Nothing to close against if the base layer has no boundaries yet.
        base_bounds = hierarchy['boundaries'].get(0)
        if not base_bounds:
            return hierarchy
        final_boundary = base_bounds[-1]

        # close open boundaries
        for layer_name, layer_bounds in hierarchy['boundaries'].items():
            if layer_name == 0:
                continue
            if not layer_bounds or layer_bounds[-1] != final_boundary:
                layer_bounds.append(final_boundary)

        return hierarchy

    def reset_model(self):
        r"""
        Resets the whole streamer model for a new video.
        Calls recursive function :py:meth:`~streamer.models.layer.StreamerLayer.reset_layer`
        on the :py:class:`~streamer.models.layer.StreamerLayer` class.
        """
        if self.streamer is None:
            return
        self.streamer.reset_layer()
        self.base_counter = 0

    def get_num_layers(self):
        r"""
        Get the total number of layers.
        Calls recursive function :py:meth:`~streamer.models.layer.StreamerLayer.get_num_layers`
        on the :py:class:`~streamer.models.layer.StreamerLayer` class.

        :returns: (*int*): 0 if no layer has been created
        """
        if self.streamer is None:
            return 0
        return self.streamer.get_num_layers(0)

    def optimize_model(self):
        r"""
        Optimizes the whole streamer model (gradient step).
        Calls recursive function :py:meth:`~streamer.models.layer.StreamerLayer.optimize_layer`
        on the :py:class:`~streamer.models.layer.StreamerLayer` class.
        """
        if self.streamer is not None:
            self.streamer.optimize_layer()

    def save_model(self):
        r"""
        Saves the model weights to :py:class:`~StreamerModelArguments.ckpt_dir`
        as ``model_XXX.pth``, where XXX is the zero-padded save counter.

        Does nothing when no layer exists or when this is not the main process.
        Creates ``ckpt_dir`` if it does not exist.
        """
        if self.streamer is None or not self.args.main:
            return

        ckpt = dict(
            model_args=asdict(self.args),
            # `modality` is never assigned by this class; default to None so that
            # saving does not raise AttributeError while keeping the key in the format.
            model_modality=getattr(self, 'modality', None),
            model_snippet_size=self.args.snippet_size,
            num_layers=self.get_num_layers(),
            weights=self.streamer.state_dict(),
        )

        ckpt_dir = self.args.ckpt_dir
        if ckpt_dir:
            os.makedirs(ckpt_dir, exist_ok=True)
        filename = f'model_{str(self.models_saved).zfill(3)}.pth'
        torch.save(ckpt, os.path.join(ckpt_dir, filename))
        self.models_saved += 1