import torch
import torch.nn as nn
import torch.nn.functional as F
from streamer.utils.events import CounterDetector
from streamer.utils.events import Patcher
from torch.utils.tensorboard import SummaryWriter
from streamer.utils.logging import TBWriter
from collections import deque
import streamer.models.networks as networks
from dataclasses import dataclass
from streamer.core.demarcation import EventDemarcation
from streamer.core.loss import StreamerLoss
from streamer.core.buffer import MemBuffer
import types
@dataclass
class StreamerLayerArguments:
r"""Arguments shared by every :py:class:`StreamerLayer` in the stack."""
# architecture arguments
max_layers: int
r"""The maximum number of layers to stack"""
feature_dim: int
r"""Feature dimension of the model embeddings"""
evolve_every: int
r"""Create/stack a new layer every 'evolve_every' """
buffer_size: int
r"""Maximum input buffer size to be used"""
loss_threshold: float
r"""Loss threshold value. Not used in average demarcation mode"""
lr: float
r"""Learning rate to be used in all modules"""
reps_fn: types.FunctionType
r"""Function to aggregate representations from all layers"""
snippet_size: float
r"""Snippet size of input video (seconds/image). Typically 0.5 seconds per image"""
demarcation_mode: str = 'average'
r"""Demarcation mode used to detect boundaries"""
distance_mode: str = 'similarity'
r"""Distance mode for loss calculation"""
force_base_dist: bool = False
r"""Force the lowest layer to use MSE instead of Cosine Similarity"""
window_size: int = 50
r"""Window size for average demarcation mode"""
modifier_type: str = 'multiply'
r"""Modifier type to apply to average demarcation mode ['multiply', 'add']"""
modifier: float = 1.0
r"""Modifier to apply to average demarcation mode"""
force_fixed_buffer: bool = False
r"""Force the buffer to be fixed (not replacing inputs) by triggering a boundary when buffer is full"""
@staticmethod
def from_model_args(args):
r"""Builds :py:class:`StreamerLayerArguments` from the global model arguments."""
return StreamerLayerArguments(
max_layers=args.max_layers,
feature_dim=args.feature_dim,
evolve_every=args.evolve_every,
buffer_size=args.buffer_size,
force_fixed_buffer=args.force_fixed_buffer,
loss_threshold=args.loss_threshold,
lr=args.lr,
reps_fn=args.reps_fn,
snippet_size=args.snippet_size,
demarcation_mode=args.demarcation_mode,
distance_mode=args.distance_mode,
force_base_dist=args.force_base_dist,
window_size=args.window_size,
modifier_type=args.modifier_type,
modifier=args.modifier,
)
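# Illustrative construction (a minimal sketch; the values below are arbitrary and
# `my_reps_fn` stands in for whatever callable aggregates representations across
# layers in your setup):
#
#   args = StreamerLayerArguments(
#       max_layers=3,
#       feature_dim=1024,
#       evolve_every=100,
#       buffer_size=20,
#       loss_threshold=0.1,
#       lr=1e-4,
#       reps_fn=my_reps_fn,
#       snippet_size=0.5,
#   )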
class StreamerLayer(nn.Module):
r"""
STREAMER layer implementation.
This layer can:
* Create/stack other :py:class:`StreamerLayer` layers recursively, up to a maximum of :py:attr:`~StreamerLayerArguments.max_layers`
* Call the :py:class:`StreamerLayer` above it by propagating the current representation
* Calculate and store the loss for the :py:class:`~streamer.optimizers.streamer_optimizer.StreamerOptimizer` to use
:param StreamerLayerArguments args: arguments provided to every streamer layer
:param int layer_num: the index of the current layer in the layers stack
:param int init_count: used to create more layers at initialization. Useful for an inference model using pretrained weights.
:param torch.nn.Module encoder: Encoder class to be used at this layer. Passed later to the :py:class:`~streamer.models.networks.TemporalEncoding` module
:param torch.nn.Module decoder: Decoder class to be used at this layer. Passed later to the :py:class:`~streamer.models.networks.HierarchicalPrediction` module
:param Logger logger: Logger to be used for tensorboard.
"""
def __init__(self,
args,
layer_num,
init_count,
encoder=None,
decoder=None,
logger=None
):
super(StreamerLayer, self).__init__()
# === References and Counters === #
self.args = args
self.logger = logger
self.layer_num = layer_num
self.init_count = init_count
self.distance_mode = 'distance' if (args.force_base_dist and self.layer_num == 0) else args.distance_mode
self.preprocess = (encoder is not None)
self.above = None
self.layer_creator = CounterDetector(count=self.args.evolve_every)
# === Memory Buffer === #
self.buffer = MemBuffer(args.buffer_size, self.distance_mode)
# === F function (temporal encoding of the buffered inputs into one representation) === #
self.f = networks.TemporalEncoding(
feature_dim=args.feature_dim,
buffer_size=args.buffer_size,
lr=args.lr,
num_layers=2,
n_heads=8,
encoder=encoder,
patch=False,
)
# === P function (hierarchical prediction of the next input at this level) === #
self.p = networks.HierarchicalPrediction(
feature_dim=args.feature_dim,
max_layers=args.max_layers,
lr=args.lr,
layer_num=self.layer_num,
num_layers=2,
n_heads=8,
decoder=decoder,
patch=False,
)
# === Loss Function === #
self.l = StreamerLoss(dist_mode=self.distance_mode)
# === Demarcation Function === #
self.demarcation = EventDemarcation(
dem_mode=args.demarcation_mode,
dist_mode=self.distance_mode,
threshold=args.loss_threshold,
window_size=args.window_size,
modifier=args.modifier,
modifier_type=args.modifier_type
)
# reset and clear
self.reset_layer()
print(f'created layer {self.layer_num}')
# create more parent layers at init if needed
self.create_parent(self.init_count > 1)
def reset_layer(self):
r"""
Reset function to be used at the beginning of a new video.
Recursively applied to every layer.
"""
self.buffer.reset_buffer()
self.demarcation.reset()
self.hierarchy_boundaries = [0]
self.hierarchy_attn = []
self.representation = None
self.attn_img = None
self.attns = None
self.context = None
self.objective_ready = False
self.last_base_counter = -1
if self.above is not None: self.above.reset_layer()
def extract_rep(self, hierarchy):
r"""
Extracts the hierarchy for json logging.
Recursively applied to every layer.
:param dict hierarchy: hierarchy dictionary to be filled with boundaries
"""
hierarchy['boundaries'][self.layer_num] = self.hierarchy_boundaries.copy()
if self.above is not None:
self.above.extract_rep(hierarchy)
def optimize_layer(self):
r"""
Optimization step function. Steps then zeros the gradients.
Calls the `step_params()` and `zero_grad()` functions of every module. (e.g., :py:meth:`~streamer.models.networks.TemporalEncoding.step_params`)
Recursively applied to every layer.
"""
# Step the optimizers
self.f.step_params()
self.p.step_params()
self.l.step_params()
# zero the gradients
self.f.zero_grad()
self.p.zero_grad()
self.l.zero_grad()
# reset representations because parameters have changed
self.representation = None
self.context = None
if self.above is not None:
self.above.optimize_layer()
def get_num_layers(self, num):
r"""
Recursive function to get the total number of layers
:param int num: current number of layers at the previous layer
:returns:
(*int*): previous number of layers + 1
"""
if self.above is not None:
return self.above.get_num_layers(num+1)
else: return num+1
def create_parent(self, create):
r"""
Function to create/stack another :py:class:`StreamerLayer` layer
:param bool create: only add another layer if create is True
"""
if (self.above is None) and create and (self.layer_num < self.args.max_layers - 1) and self.training:
self.above = StreamerLayer(self.args,
self.layer_num+1,
self.init_count-1,
logger=self.logger).cuda()
def predict(self):
r"""
Prediction function that calls the :py:class:`~streamer.models.networks.TemporalEncoding` and :py:class:`~streamer.models.networks.HierarchicalPrediction` modules
"""
# === prepare input === #
f_in = torch.cat(self.buffer.get_inputs(), dim=0)
self.attn_img = f_in[0]
# === Temporal Encoding === #
# input:  [S, 3, 128, 128] or [S, 1024]
# output: [1, 1024]
self.representation, self.attns = self.f(f_in)
if self.distance_mode == "similarity": self.representation = F.normalize(self.representation, p=2, dim=-1)
# === Hierarchical Prediction === #
# input:  [L, 1024]
# output: [1, 3, 128, 128] or [1, 1024]
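# `reps_fn` is expected to return the stacked representations relevant to this
# layer (shape [L, feature_dim], per the note above); the P function then
# predicts the next input at this level from that context.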
reps = self.args.reps_fn(self.layer_num)
self.context = self.p(reps)
def forward(self, x, base_counter):
r"""
Forward propagation function for a layer. Recursively calls the layer above at event boundary determined by the :py:class:`~streamer.core.demarcation.EventDemarcation` module.
:param torch.Tensor x: the input feature vector [1, feature_dim] or image [1, 3, H, W]
:param int base_counter: the location of this input in the video for timescale calculation in the :py:class:`~streamer.optimizers.streamer_optimizer.StreamerOptimizer`
"""
# create parent if needed
self.create_parent(self.layer_creator())
self.x_size = base_counter - self.last_base_counter
self.last_base_counter = base_counter
# check if layer is new or recently optimized
if self.context is None:
self.buffer.reset_buffer()
self.buffer.add_input(x, 0.0)
self.predict()
return
# check loss of prediction with input
self.main_objective, demarcation_signal = self.l(self.context, x)
self.objective_ready = True
# check boundary
boundary_demarcation = self.demarcation(demarcation_signal.item())
boundary_buffer = self.args.force_fixed_buffer and self.buffer.counter >= self.buffer.buffer_size
boundary = boundary_demarcation or boundary_buffer
# Tensorboard logging
if self.logger is not None:
self.logger.layer(self.layer_num,
boundary,
demarcation_signal.item(),
self.buffer.counter,
self.representation.detach())
# new event created at this level, send current event upwards
if boundary:
# save boundary values for json
self.hierarchy_boundaries.append(base_counter*self.args.snippet_size)
# if above layer exists, send it up
if self.above is not None: self.above(self.representation.detach().clone(), base_counter=base_counter)
self.buffer.reset_buffer()
# represent and predict
self.buffer.add_input(x, demarcation_signal.item())
self.predict()
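# Usage sketch (illustrative only; assumes CUDA and an `args` object built as in
# the StreamerLayerArguments example above; optimization is driven externally,
# e.g. by a StreamerOptimizer calling optimize_layer()):
#
#   layer0 = StreamerLayer(args, layer_num=0, init_count=1, logger=None).cuda()
#   for t, snippet in enumerate(video_snippets):
#       # snippet: [1, feature_dim] feature vector (or an image if an encoder is supplied)
#       layer0(snippet.cuda(), base_counter=t)
#   layer0.reset_layer()  # reset before the next video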