mshukor
init
3eb682b
raw
history blame
4.87 kB
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
"""Data loader."""
import itertools
import numpy as np
import torch
from torch.utils.data._utils.collate import default_collate
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler
from timesformer.datasets.multigrid_helper import ShortCycleBatchSampler
from . import utils as utils
from .build import build_dataset
def detection_collate(batch):
    """
    Collate function for detection task. Concatenate bboxes, labels and
    metadata from different samples in the first dimension instead of
    stacking them to have a batch-size dimension.
    Args:
        batch (tuple or list): data batch to collate. Each sample is a
            4-tuple `(inputs, labels, video_idx, extra_data)` where
            `extra_data` is a dict; the keys of the first sample's dict
            decide which keys are collated.
    Returns:
        (tuple): collated detection data batch
            `(inputs, labels, video_idx, collated_extra_data)`. In
            `collated_extra_data`, "boxes" and "ori_boxes" are float tensors
            whose first column is the index of the originating sample in the
            batch, "metadata" is reshaped to rows of two values, and every
            other key goes through `default_collate`.
    """
    inputs, labels, video_idx, extra_data = zip(*batch)
    # Inputs and video indices have a fixed per-sample layout, so they can be
    # stacked as usual.
    inputs, video_idx = default_collate(inputs), default_collate(video_idx)
    # Labels are concatenated along the first axis rather than stacked.
    labels = torch.tensor(np.concatenate(labels, axis=0)).float()

    collated_extra_data = {}
    for key in extra_data[0]:
        data = [d[key] for d in extra_data]
        if key in ("boxes", "ori_boxes"):
            # Prepend the sample index as an extra column so each box can be
            # traced back to its sample after concatenation.
            bboxes = np.concatenate(
                [
                    np.concatenate(
                        [np.full((boxes.shape[0], 1), float(i)), boxes],
                        axis=1,
                    )
                    for i, boxes in enumerate(data)
                ],
                axis=0,
            )
            collated_extra_data[key] = torch.tensor(bboxes).float()
        elif key == "metadata":
            # Flatten the per-sample metadata lists into one list of entries,
            # then view the result as rows of two values.
            collated_extra_data[key] = torch.tensor(
                list(itertools.chain.from_iterable(data))
            ).view(-1, 2)
        else:
            collated_extra_data[key] = default_collate(data)

    return inputs, labels, video_idx, collated_extra_data
def construct_loader(cfg, split, is_precise_bn=False):
    """
    Constructs the data loader for the given dataset.
    Args:
        cfg (CfgNode): configs. Details can be found in
            slowfast/config/defaults.py
        split (str): the split of the data loader. Options include `train`,
            `val`, and `test`.
        is_precise_bn (bool): if True, the multigrid short-cycle batch
            sampler is not used even when it is enabled for the `train`
            split; a regular fixed-batch-size loader is built instead.
    Returns:
        (torch.utils.data.DataLoader): loader over the dataset built for
            `split`.
    """
    assert split in ["train", "val", "test"]
    # Note: `val` reads its dataset name and batch size from cfg.TRAIN; only
    # `test` reads them from cfg.TEST.
    split_cfg = cfg.TEST if split == "test" else cfg.TRAIN
    dataset_name = split_cfg.DATASET
    # The configured batch size is divided by the GPU count (at least 1).
    batch_size = int(split_cfg.BATCH_SIZE / max(1, cfg.NUM_GPUS))
    # Only the training split is shuffled and drops its last partial batch.
    shuffle = split == "train"
    drop_last = split == "train"

    # Construct the dataset
    dataset = build_dataset(dataset_name, cfg, split)

    use_short_cycle = (
        cfg.MULTIGRID.SHORT_CYCLE and split == "train" and not is_precise_bn
    )

    # Create a sampler for multi-process training
    sampler = utils.create_sampler(dataset, shuffle, cfg)

    if use_short_cycle:
        batch_sampler = ShortCycleBatchSampler(
            sampler, batch_size=batch_size, drop_last=drop_last, cfg=cfg
        )
        # batch_sampler is mutually exclusive with batch_size, shuffle,
        # sampler and drop_last, so none of those are passed here.
        loader = torch.utils.data.DataLoader(
            dataset,
            batch_sampler=batch_sampler,
            num_workers=cfg.DATA_LOADER.NUM_WORKERS,
            pin_memory=cfg.DATA_LOADER.PIN_MEMORY,
            worker_init_fn=utils.loader_worker_init_fn(dataset),
        )
    else:
        loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            # DataLoader rejects shuffle=True together with any non-None
            # sampler. Compare against None explicitly: truthiness of a
            # sampler goes through __len__, so an empty sampler would be
            # treated as "no sampler" and trigger that error.
            shuffle=(shuffle if sampler is None else False),
            sampler=sampler,
            num_workers=cfg.DATA_LOADER.NUM_WORKERS,
            pin_memory=cfg.DATA_LOADER.PIN_MEMORY,
            drop_last=drop_last,
            collate_fn=detection_collate if cfg.DETECTION.ENABLE else None,
            worker_init_fn=utils.loader_worker_init_fn(dataset),
        )
    return loader
def shuffle_dataset(loader, cur_epoch):
    """
    Shuffles the data by informing the loader's sampler of the current epoch.
    Args:
        loader (loader): data loader to perform shuffle.
        cur_epoch (int): number of the current epoch.
    """
    # With a short-cycle batch sampler the underlying sampler is nested one
    # level deeper, inside the batch sampler.
    if isinstance(loader.batch_sampler, ShortCycleBatchSampler):
        sampler = loader.batch_sampler.sampler
    else:
        sampler = loader.sampler

    supported_samplers = (RandomSampler, DistributedSampler)
    assert isinstance(
        sampler, supported_samplers
    ), "Sampler type '{}' not supported".format(type(sampler))

    # A RandomSampler needs no action here; only a DistributedSampler is told
    # the epoch number.
    if isinstance(sampler, DistributedSampler):
        sampler.set_epoch(cur_epoch)