STREAMER
STREAMER is a predictive learning model that continually trains to improve its future predictions at different timescales. It uses the prediction error to segment events in a hierarchical manner while processing streaming videos.
The API documentation includes code snippets showing how to instantiate the different classes. We also provide simple training and inference scripts to reproduce the results of the NeurIPS’23 paper.
Training Script
Note
This training script uses the command-line arguments defined in the Arguments section.
The helper bash scripts predefine arguments for GPU and SLURM machines for the Ego4D and EPIC-KITCHENS datasets.
Note
We use the DDPW library to easily parallelize the code across multiple GPUs or multiple SLURM nodes.
import torch.nn.functional as F
import torch.distributed as dist
from ddpw import Platform, Wrapper
from streamer.arguments.base_arguments import parser
from streamer.utils.distributed import init_gpu
from streamer.utils.logging import setup_output, JsonLogger
import streamer.datasets as datasets
import streamer.models as models
import streamer.optimizers as optimizers
from tqdm import tqdm
def train_gpu(global_rank, local_rank, args):
    """Run the STREAMER training loop on a single process.

    This is the per-process entry point handed to DDPW's ``Wrapper.start``.
    For every batch it:

    1. optionally L2-normalizes the frames;
    2. runs the model forward;
    3. optionally steps the optimizer;
    4. at the end of each video, logs the learned hierarchy to JSON and
       resets the model;
    5. periodically saves a checkpoint.

    Args:
        global_rank: Rank of this process among all processes, as supplied by
            DDPW (presumably unique across nodes -- confirm against DDPW docs).
        local_rank: Rank of this process on its own node, as supplied by DDPW.
        args: Parsed command-line arguments (post ``setup_output``). Fields
            read here include ``device``, ``optimize``, ``main``,
            ``normalize_imgs``, ``save_every``, ``world_size``, ``logger``,
            ``log_prefix``, ``log_postfix``, ``snippet_size`` and
            ``json_dir``.
    """
    # Initialize the GPU (and, presumably, the process group / TensorBoard
    # writer) for this rank. The return value of init_gpu is not used.
    init_gpu(global_rank, local_rank, args)

    loader = datasets.find_dataset_using_name(args)
    model = models.getModel(args).to(args.device)

    # Hierarchical gradient normalization and optimizer. Only built when
    # training; None otherwise so later checks need no unbound-name guard.
    optimizer = optimizers.getOptimizer(args, model) if args.optimize else None

    json_logger = JsonLogger(is_inference=False,
                             prefix=args.log_prefix,
                             postfix=args.log_postfix,
                             snippet_size=args.snippet_size,
                             json_dir=args.json_dir)

    # Progress bar only where args.main is set (presumably the main rank).
    if args.main:
        loader = tqdm(loader)

    # Stays None when the loader yields no batches, so the final checkpoint
    # step after the loop does not hit an unbound name.
    batch_ix = None
    for batch_ix, (frames, info) in enumerate(loader):
        if args.normalize_imgs:
            # L2-normalize along dim 1 (presumably the channel axis -- TODO confirm).
            frames = F.normalize(frames, dim=1, p=2)

        # Forward pass. No explicit backward() is called here, so the model
        # presumably computes its losses/gradients internally.
        model(frames.to(args.device))

        if optimizer is not None:
            optimizer.step()

        # info[-1] flags the end of a video: log the learned hierarchy and
        # reset the model state before the next video starts.
        if info[-1] == True:  # noqa: E712 -- info[-1] may be a one-element tensor
            json_logger(filepath=info[0][0],
                        duration=info[1].item(),
                        hierarchy=model.extract_rep())
            model.reset_model()

        # Periodic checkpoint every args.save_every batches.
        if args.optimize and batch_ix % args.save_every == 0:
            model.save_model()

        # Keep all ranks in lockstep while optimizing.
        if args.world_size > 1 and args.optimize:
            dist.barrier()

    # Save the final weights, unless the last in-loop save (same batch, no
    # further updates since) already wrote them.
    if (args.optimize and batch_ix is not None
            and batch_ix % args.save_every != 0):
        model.save_model()

    if args.world_size > 1:
        dist.barrier()
        dist.destroy_process_group()

    # Release the logger attached to args before the process exits.
    if args.logger is not None:
        del args.logger
if __name__ == "__main__":
    # Parse the command line, then let setup_output finalize the argument
    # namespace (output locations, etc.) before anything is launched.
    args = setup_output(parser().parse_args())

    # Describe where training runs: every p_* argument maps one-to-one onto
    # a DDPW Platform field (local GPUs or a SLURM allocation).
    platform = Platform(
        name=args.p_name,
        device=args.p_device,
        partition=args.p_partition,
        n_nodes=args.p_n_nodes,
        n_gpus=args.p_n_gpus,
        n_cpus=args.p_n_cpus,
        ram=args.p_ram,
        backend=args.p_backend,
        console_logs=args.p_logs,
        verbose=args.p_verbose,
    )

    # Hand train_gpu to DDPW, which launches it on each process it manages.
    Wrapper(platform=platform).start(train_gpu, args=args)
Inference Script
Note
Pretrained weights will be released soon.
from streamer.models.inference_model import InferenceModel
# Load a trained STREAMER model from a checkpoint path.
inference_model = InferenceModel(checkpoint='to/checkpoint/path/')

# Run the model on a single video file; `result` holds whatever the model
# returns for that video (presumably its event hierarchy -- see API docs).
result = inference_model(filename='to/video/file/path')