# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

"""Data loader."""

import itertools

import numpy as np
import torch
from torch.utils.data._utils.collate import default_collate
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler

from timesformer.datasets.multigrid_helper import ShortCycleBatchSampler

from . import utils as utils
from .build import build_dataset


def detection_collate(batch):
    """
    Collate function for detection task. Concatenate bboxes, labels and
    metadata from different samples in the first dimension instead of
    stacking them to have a batch-size dimension.
    Args:
        batch (tuple or list): data batch to collate.
    Returns:
        (tuple): collated detection data batch.
    """
    inputs, labels, video_idx, extra_data = zip(*batch)
    inputs, video_idx = default_collate(inputs), default_collate(video_idx)
    # Labels are variable-length per sample, so concatenate instead of stack.
    labels = torch.tensor(np.concatenate(labels, axis=0)).float()

    collated_extra_data = {}
    for key in extra_data[0].keys():
        data = [d[key] for d in extra_data]
        if key == "boxes" or key == "ori_boxes":
            # Prepend the sample index as an extra column so boxes from
            # different clips remain distinguishable after concatenation.
            bboxes = [
                np.concatenate(
                    [np.full((data[i].shape[0], 1), float(i)), data[i]], axis=1
                )
                for i in range(len(data))
            ]
            bboxes = np.concatenate(bboxes, axis=0)
            collated_extra_data[key] = torch.tensor(bboxes).float()
        elif key == "metadata":
            # Flatten the per-sample metadata lists into an (N, 2) tensor.
            collated_extra_data[key] = torch.tensor(
                list(itertools.chain(*data))
            ).view(-1, 2)
        else:
            collated_extra_data[key] = default_collate(data)

    return inputs, labels, video_idx, collated_extra_data


def construct_loader(cfg, split, is_precise_bn=False):
    """
    Constructs the data loader for the given dataset.
    Args:
        cfg (CfgNode): configs. Details can be found in
            slowfast/config/defaults.py
        split (str): the split of the data loader. Options include `train`,
            `val`, and `test`.
        is_precise_bn (bool): if True, bypass the multigrid short-cycle
            batch sampler even during training.
    Returns:
        (torch.utils.data.DataLoader): the constructed data loader.
    """
    assert split in ["train", "val", "test"]
    if split in ["train"]:
        dataset_name = cfg.TRAIN.DATASET
        # Per-GPU batch size; guard against NUM_GPUS == 0 (CPU-only runs).
        batch_size = int(cfg.TRAIN.BATCH_SIZE / max(1, cfg.NUM_GPUS))
        shuffle = True
        drop_last = True
    elif split in ["val"]:
        dataset_name = cfg.TRAIN.DATASET
        batch_size = int(cfg.TRAIN.BATCH_SIZE / max(1, cfg.NUM_GPUS))
        shuffle = False
        drop_last = False
    elif split in ["test"]:
        dataset_name = cfg.TEST.DATASET
        batch_size = int(cfg.TEST.BATCH_SIZE / max(1, cfg.NUM_GPUS))
        shuffle = False
        drop_last = False

    # Construct the dataset
    dataset = build_dataset(dataset_name, cfg, split)

    if cfg.MULTIGRID.SHORT_CYCLE and split in ["train"] and not is_precise_bn:
        # Create a sampler for multi-process training
        sampler = utils.create_sampler(dataset, shuffle, cfg)
        # Short-cycle multigrid training varies the batch size per cycle,
        # so batching is delegated to a dedicated batch sampler.
        batch_sampler = ShortCycleBatchSampler(
            sampler, batch_size=batch_size, drop_last=drop_last, cfg=cfg
        )
        # Create a loader
        loader = torch.utils.data.DataLoader(
            dataset,
            batch_sampler=batch_sampler,
            num_workers=cfg.DATA_LOADER.NUM_WORKERS,
            pin_memory=cfg.DATA_LOADER.PIN_MEMORY,
            worker_init_fn=utils.loader_worker_init_fn(dataset),
        )
    else:
        # Create a sampler for multi-process training
        sampler = utils.create_sampler(dataset, shuffle, cfg)
        # Create a loader. When a sampler is supplied, DataLoader forbids
        # shuffle=True, so shuffling is then the sampler's responsibility.
        loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=(False if sampler else shuffle),
            sampler=sampler,
            num_workers=cfg.DATA_LOADER.NUM_WORKERS,
            pin_memory=cfg.DATA_LOADER.PIN_MEMORY,
            drop_last=drop_last,
            collate_fn=detection_collate if cfg.DETECTION.ENABLE else None,
            worker_init_fn=utils.loader_worker_init_fn(dataset),
        )
    return loader


def shuffle_dataset(loader, cur_epoch):
    """
    Shuffles the data.
    Args:
        loader (loader): data loader to perform shuffle.
        cur_epoch (int): number of the current epoch.
    """
    # The short-cycle batch sampler wraps the real sampler one level deeper.
    sampler = (
        loader.batch_sampler.sampler
        if isinstance(loader.batch_sampler, ShortCycleBatchSampler)
        else loader.sampler
    )
    assert isinstance(
        sampler, (RandomSampler, DistributedSampler)
    ), "Sampler type '{}' not supported".format(type(sampler))
    # RandomSampler handles shuffling automatically
    if isinstance(sampler, DistributedSampler):
        # DistributedSampler shuffles data based on epoch
        sampler.set_epoch(cur_epoch)