STREAMER#

STREAMER is a predictive learning model that trains continually to improve its future predictions at multiple timescales. While processing streaming video, it uses its prediction error to segment events hierarchically.
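
As a rough, self-contained illustration of this idea (and not the actual STREAMER architecture or training objective), a single layer could flag an event boundary whenever the error of predicting the next feature spikes; the predictor, error measure, and threshold below are illustrative assumptions.

import torch
import torch.nn.functional as F

def detect_boundaries(features, predictor, threshold=0.5):
    # toy sketch: mark a boundary whenever the next-step prediction error
    # exceeds a threshold; this rule is an assumption for illustration,
    # not the STREAMER loss or segmentation criterion
    boundaries = []
    for t in range(len(features) - 1):
        predicted = predictor(features[t])              # predict the next feature
        error = F.mse_loss(predicted, features[t + 1])  # prediction error
        boundaries.append(error.item() > threshold)     # boundary if the error spikes
    return boundaries

# usage with random features and a trivial "repeat the current feature" predictor
feats = torch.randn(16, 32)
print(detect_boundaries(feats, predictor=lambda x: x))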

The API documentation includes code snippets showing how to instantiate the different classes. We also provide simple training and inference scripts to reproduce the results in the NeurIPS’23 paper.

Training Script#

Note

This training script uses the command-line arguments defined in Arguments.

The helper bash scripts provide predefined arguments for GPU and SLURM machines for the Ego4D and EPIC-KITCHENS datasets.

Note

We use the DDPW library to parallelize the code across multiple GPUs or multiple nodes on SLURM.

import torch.nn.functional as F
import torch.distributed as dist
from ddpw import Platform, Wrapper

from streamer.arguments.base_arguments import parser
from streamer.utils.distributed import init_gpu
from streamer.utils.logging import setup_output, JsonLogger

import streamer.datasets as datasets
import streamer.models as models
import streamer.optimizers as optimizers
from tqdm import tqdm



def train_gpu(global_rank, local_rank, args):

    # initialize gpu and tb writer, and return json logger
    init_gpu(global_rank, local_rank, args)

    # get dataloader instance
    loader = datasets.find_dataset_using_name(args)

    # get model instance
    model = models.getModel(args).to(args.device)

    # hierarchical gradient normalization and optimizer
    if args.optimize:
        optimizer = optimizers.getOptimizer(args, model)

    # get logger
    jsonLogger = JsonLogger(is_inference=False,
                            prefix=args.log_prefix,
                            postfix=args.log_postfix,
                            snippet_size=args.snippet_size,
                            json_dir=args.json_dir)


    # main training loop
    if args.main: loader = tqdm(loader)
    for batch_ix, (frames, info) in enumerate(loader):

        if args.normalize_imgs:
            frames = F.normalize(frames, dim=1, p=2)

        # forward and loss calculation
        model(frames.to(args.device))

        # backward and optimization
        if args.optimize: optimizer.step()

        # log video and reset if at the end
        if info[-1]:
            json_file = jsonLogger(filepath=info[0][0],
                            duration=info[1].item(),
                            hierarchy=model.extract_rep())
            model.reset_model()

        # save model every args.save_every
        if args.optimize and batch_ix % args.save_every == 0:
            model.save_model()

        # distributed barrier
        if args.world_size>1 and args.optimize:
            dist.barrier()



    # save the final model at the end of training
    if args.optimize:
        model.save_model()

    if args.world_size > 1:
        dist.barrier()
        dist.destroy_process_group()

    if args.logger is not None: del args.logger


if __name__ == "__main__":

    args = parser().parse_args()
    args = setup_output(args)

    platform = Platform(name=args.p_name,
                        device=args.p_device,
                        partition=args.p_partition,
                        n_nodes=args.p_n_nodes,
                        n_gpus=args.p_n_gpus,
                        n_cpus=args.p_n_cpus,
                        ram=args.p_ram,
                        backend=args.p_backend,
                        console_logs=args.p_logs,
                        verbose=args.p_verbose)

    wrapper = Wrapper(platform=platform)

    # start training
    wrapper.start(train_gpu, args=args)
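
For quick single-process debugging, the __main__ block above could in principle be replaced with a direct call to train_gpu, bypassing the DDPW wrapper. This shortcut is an assumption rather than part of the provided scripts, and it presumes that the parsed arguments describe a non-distributed run (args.world_size of 1).

if __name__ == "__main__":

    args = parser().parse_args()
    args = setup_output(args)

    # assumed shortcut: run rank 0 locally in a single process
    # instead of launching through wrapper.start
    train_gpu(global_rank=0, local_rank=0, args=args)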

Inference Script#

Note

Pretrained weights will be released soon.

from streamer.models.inference_model import InferenceModel

model = InferenceModel(checkpoint='path/to/checkpoint/')
result = model(filename='path/to/video/file')
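
The structure of result is not specified here. Assuming it is a JSON-serializable description of the predicted event hierarchy (analogous to what JsonLogger writes during training), it could be saved as follows; the output filename is an arbitrary placeholder.

import json

# assumption: `result` is JSON-serializable; adjust this if the released
# InferenceModel returns a different structure
with open('hierarchy.json', 'w') as f:
    json.dump(result, f, indent=2)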