"""Data loader.""" |
|
import itertools

import numpy as np
import torch
from torch.utils.data._utils.collate import default_collate
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler

from timesformer.datasets.multigrid_helper import ShortCycleBatchSampler

from . import utils as utils
from .build import build_dataset
|
def detection_collate(batch):
    """
    Collate function for detection tasks. Concatenate bboxes, labels and
    metadata from different samples along the first dimension instead of
    stacking them to create a batch-size dimension.
    Args:
        batch (tuple or list): data batch to collate.
    Returns:
        (tuple): collated detection data batch.
    """
    inputs, labels, video_idx, extra_data = zip(*batch)
    inputs, video_idx = default_collate(inputs), default_collate(video_idx)
    # Samples can carry different numbers of boxes, so concatenate the labels
    # along the first dimension instead of stacking them.
    labels = torch.tensor(np.concatenate(labels, axis=0)).float()

    collated_extra_data = {}
    for key in extra_data[0].keys():
        data = [d[key] for d in extra_data]
        if key == "boxes" or key == "ori_boxes":
            # Prepend the index of the sample within the batch as the first
            # column of each box, so boxes can be traced back to their clip.
            bboxes = [
                np.concatenate(
                    [np.full((data[i].shape[0], 1), float(i)), data[i]], axis=1
                )
                for i in range(len(data))
            ]
            bboxes = np.concatenate(bboxes, axis=0)
            collated_extra_data[key] = torch.tensor(bboxes).float()
        elif key == "metadata":
            # Flatten the per-sample metadata lists into an (N, 2) tensor.
            collated_extra_data[key] = torch.tensor(
                list(itertools.chain(*data))
            ).view(-1, 2)
        else:
            collated_extra_data[key] = default_collate(data)

    return inputs, labels, video_idx, collated_extra_data
|
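# Illustrative sketch of the collation behavior (names and shapes below are
# hypothetical, not from the original source): for a batch of two samples
# whose "boxes" entries have shapes (2, 4) and (3, 4),
#
#   inputs, labels, video_idx, extra = detection_collate(batch)
#
# extra["boxes"] comes out with shape (5, 5), where each row is
# [sample_idx, x1, y1, x2, y2], so RoI-style detection heads can map every
# box back to its clip without relying on a batch dimension.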
|
|
def construct_loader(cfg, split, is_precise_bn=False):
    """
    Constructs the data loader for the given dataset.
    Args:
        cfg (CfgNode): configs. Details can be found in
            timesformer/config/defaults.py
        split (str): the split of the data loader. Options include `train`,
            `val`, and `test`.
        is_precise_bn (bool): if True, build the loader used for computing
            precise batch-norm statistics; the multigrid short-cycle batch
            sampler is bypassed in that case.
    """
    assert split in ["train", "val", "test"]
    if split in ["train"]:
        dataset_name = cfg.TRAIN.DATASET
        # Scale the global batch size down to a per-GPU batch size.
        batch_size = int(cfg.TRAIN.BATCH_SIZE / max(1, cfg.NUM_GPUS))
        shuffle = True
        drop_last = True
    elif split in ["val"]:
        dataset_name = cfg.TRAIN.DATASET
        batch_size = int(cfg.TRAIN.BATCH_SIZE / max(1, cfg.NUM_GPUS))
        shuffle = False
        drop_last = False
    elif split in ["test"]:
        dataset_name = cfg.TEST.DATASET
        batch_size = int(cfg.TEST.BATCH_SIZE / max(1, cfg.NUM_GPUS))
        shuffle = False
        drop_last = False

    # Construct the dataset.
    dataset = build_dataset(dataset_name, cfg, split)

    if cfg.MULTIGRID.SHORT_CYCLE and split in ["train"] and not is_precise_bn:
        # Multigrid training: wrap the sampler in a short-cycle batch
        # sampler that varies the spatial resolution across batches.
        sampler = utils.create_sampler(dataset, shuffle, cfg)
        batch_sampler = ShortCycleBatchSampler(
            sampler, batch_size=batch_size, drop_last=drop_last, cfg=cfg
        )
        loader = torch.utils.data.DataLoader(
            dataset,
            batch_sampler=batch_sampler,
            num_workers=cfg.DATA_LOADER.NUM_WORKERS,
            pin_memory=cfg.DATA_LOADER.PIN_MEMORY,
            worker_init_fn=utils.loader_worker_init_fn(dataset),
        )
    else:
        # Create a sampler for multi-process training.
        sampler = utils.create_sampler(dataset, shuffle, cfg)
        loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=(False if sampler else shuffle),
            sampler=sampler,
            num_workers=cfg.DATA_LOADER.NUM_WORKERS,
            pin_memory=cfg.DATA_LOADER.PIN_MEMORY,
            drop_last=drop_last,
            collate_fn=detection_collate if cfg.DETECTION.ENABLE else None,
            worker_init_fn=utils.loader_worker_init_fn(dataset),
        )
    return loader
|
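# Minimal usage sketch (assuming `cfg` is a CfgNode built from this
# project's config system):
#
#   train_loader = construct_loader(cfg, "train")
#   val_loader = construct_loader(cfg, "val")
#   test_loader = construct_loader(cfg, "test")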
|
|
def shuffle_dataset(loader, cur_epoch):
    """
    Shuffles the data.
    Args:
        loader (loader): data loader to perform shuffle.
        cur_epoch (int): number of the current epoch.
    """
    sampler = (
        loader.batch_sampler.sampler
        if isinstance(loader.batch_sampler, ShortCycleBatchSampler)
        else loader.sampler
    )
    assert isinstance(
        sampler, (RandomSampler, DistributedSampler)
    ), "Sampler type '{}' not supported".format(type(sampler))
    # RandomSampler re-shuffles on its own every epoch; DistributedSampler
    # needs the epoch number so each epoch draws a different permutation
    # that stays consistent across processes.
    if isinstance(sampler, DistributedSampler):
        sampler.set_epoch(cur_epoch)
|
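# Typical call site (a sketch; `train_epoch` is a hypothetical training
# step, not part of this module):
#
#   for cur_epoch in range(start_epoch, cfg.SOLVER.MAX_EPOCH):
#       shuffle_dataset(train_loader, cur_epoch)
#       train_epoch(train_loader, model, ...)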
|