import os

import decord
import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms


class VideoMAE(torch.utils.data.Dataset):
    """Load your own video classification dataset.

    Parameters
    ----------
    root : str, required.
        Path to the root folder storing the dataset.
    setting : str, required.
        A text file describing the dataset, with one line per video sample.
        There are three items in each line: (1) video path; (2) video length; and (3) video label.
    train : bool, default True.
        Whether to load the training or the validation set.
    test_mode : bool, default False.
        Whether to perform evaluation on the test set.
        Usually a three-crop or ten-crop evaluation strategy is involved.
    name_pattern : str, default 'img_%05d.jpg'.
        The naming pattern of the decoded video frames.
        For example, img_00012.jpg.
    video_ext : str, default 'mp4'.
        If video_loader is set to True, please specify the video format accordingly.
    is_color : bool, default True.
        Whether the loaded image is color or grayscale.
    modality : str, default 'rgb'.
        Input modality; only rgb video frames are supported for now.
        Support for rgb difference and optical flow images may be added later.
    num_segments : int, default 1.
        Number of segments used to evenly divide the video into clips.
        A useful technique to obtain global video-level information.
        Limin Wang et al., "Temporal Segment Networks: Towards Good Practices for Deep Action Recognition", ECCV 2016.
    num_crop : int, default 1.
        Number of crops for each image. Default is 1.
        Common choices are three crops and ten crops during evaluation.
    new_length : int, default 1.
        The length of the input video clip. Default is a single image, but it can be multiple video frames.
        For example, new_length=16 means we will extract a video clip of 16 consecutive frames.
    new_step : int, default 1.
        Temporal sampling rate. For example, new_step=1 means we will extract a video clip of consecutive frames,
        while new_step=2 means we will extract a video clip of every other frame.
    randomize_interframes : bool, default False.
        If True, randomly sample the temporal step between 1 and new_step for each clip.
    temporal_jitter : bool, default False.
        Whether to temporally jitter if new_step > 1.
    video_loader : bool, default False.
        Whether to use a video loader to load data.
    use_decord : bool, default False.
        Whether to use the Decord video loader to load data. Otherwise use the mmcv video loader.
    transform : function, default None.
        A function that takes data and label and transforms them.
    data_aug : str, default 'v1'.
        Type of data augmentation to use. Supports v1, v2, v3 and v4.
    lazy_init : bool, default False.
        If set to True, build a dataset instance without loading any dataset.
    is_video_dataset : bool, default True.
        If True, parse the setting file and build the list of video clips at construction time.
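
    Examples
    --------
    A minimal usage sketch; the root folder, annotation file and ``my_transform`` below
    are placeholders, and the transform is expected to return a (data, mask) pair::

        dataset = VideoMAE(root='data/videos', setting='data/train.txt',
                           video_loader=True, use_decord=True,
                           new_length=16, new_step=2, transform=my_transform)
        frames, mask = dataset[0]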
    """

    def __init__(self,
                 root,
                 setting,
                 train=True,
                 test_mode=False,
                 name_pattern='img_%05d.jpg',
                 video_ext='mp4',
                 is_color=True,
                 modality='rgb',
                 num_segments=1,
                 num_crop=1,
                 new_length=1,
                 new_step=1,
                 randomize_interframes=False,
                 transform=None,
                 temporal_jitter=False,
                 video_loader=False,
                 use_decord=False,
                 lazy_init=False,
                 is_video_dataset=True):

        super(VideoMAE, self).__init__()

        self.root = root
        self.setting = setting
        self.train = train
        self.test_mode = test_mode
        self.is_color = is_color
        self.modality = modality
        self.num_segments = num_segments
        self.num_crop = num_crop
        self.new_length = new_length

        self.randomize_interframes = randomize_interframes
        self._new_step = new_step

        self.temporal_jitter = temporal_jitter
        self.name_pattern = name_pattern
        self.video_loader = video_loader
        self.video_ext = video_ext
        self.use_decord = use_decord
        self.transform = transform
        self.lazy_init = lazy_init

        if (not self.lazy_init) and is_video_dataset:
            self.clips = self._make_dataset(root, setting)
            if len(self.clips) == 0:
                raise RuntimeError("Found 0 video clips in subfolders of: " + root + "\n"
                                   "Check your data directory (opt.data-dir).")

    def __getitem__(self, index):

        directory, target = self.clips[index]

        if self.video_loader:
            if '.' in directory.split('/')[-1]:
                # The path in the setting file already contains the file extension.
                video_name = directory
            else:
                # Otherwise, append the configured video extension.
                video_name = '{}.{}'.format(directory, self.video_ext)

            try:
                decord_vr = decord.VideoReader(video_name, num_threads=1)
            except Exception:
                # Corrupt or unreadable video: fall back to the next sample.
                return self.__getitem__((index + 1) % len(self.clips))
            duration = len(decord_vr)

        segment_indices, skip_offsets, new_step, skip_length = self._sample_train_indices(duration)

        images = self._video_TSN_decord_batch_loader(directory, decord_vr, duration, segment_indices,
                                                     skip_offsets, new_step, skip_length)

        process_data, mask = self.transform((images, None))
        # (T*C, H, W) -> (T, C, H, W) -> (C, T, H, W)
        process_data = process_data.view((self.new_length, 3) + process_data.size()[-2:]).transpose(0, 1)

        return (process_data, mask)

    def __len__(self):
        return len(self.clips)

    def _make_dataset(self, directory, setting):
        if not os.path.exists(setting):
            raise RuntimeError("Setting file %s doesn't exist. Check opt.train-list and opt.val-list." % setting)
        clips = []
        with open(setting) as split_f:
            data = split_f.readlines()
            for line in data:
                line_info = line.split(' ')
                if len(line_info) < 2:
                    raise RuntimeError('Video input format is not correct, missing one or more elements. %s' % line)
                elif len(line_info) > 2:
                    # Paths may contain spaces: join everything but the last token back into the path.
                    line_info = (' '.join(line_info[:-1]), line_info[-1])
                clip_path = os.path.join(line_info[0])
                target = int(line_info[1])
                item = (clip_path, target)
                clips.append(item)
        return clips

    def _sample_train_indices(self, num_frames):
        if self.randomize_interframes is False:
            new_step = self._new_step
        else:
            new_step = np.random.randint(1, self._new_step + 1)

        skip_length = self.new_length * new_step

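        # TSN-style sampling: split the usable frame range into num_segments equal chunks
        # and draw one clip start uniformly at random inside each chunk.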
        average_duration = (num_frames - skip_length + 1) // self.num_segments
        if average_duration > 0:
            offsets = np.multiply(list(range(self.num_segments)), average_duration)
            offsets = offsets + np.random.randint(average_duration, size=self.num_segments)
        elif num_frames > max(self.num_segments, skip_length):
            offsets = np.sort(np.random.randint(num_frames - skip_length + 1, size=self.num_segments))
        else:
            offsets = np.zeros((self.num_segments,))

        if self.temporal_jitter:
            skip_offsets = np.random.randint(new_step, size=skip_length // new_step)
        else:
            skip_offsets = np.zeros(skip_length // new_step, dtype=int)
        # Offsets are returned 1-based; the batch loader converts them back to 0-based frame ids.
        return offsets + 1, skip_offsets, new_step, skip_length

    def _video_TSN_decord_batch_loader(self, directory, video_reader, duration, indices, skip_offsets,
                                       new_step, skip_length):
        sampled_list = []
        frame_id_list = []
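        # Collect 0-based frame indices for every segment, stepping by new_step and
        # clamping at the clip duration, then decode them with a single batched read.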
        for seg_ind in indices:
            offset = int(seg_ind)
            for i, _ in enumerate(range(0, skip_length, new_step)):
                if offset + skip_offsets[i] <= duration:
                    frame_id = offset + skip_offsets[i] - 1
                else:
                    frame_id = offset - 1
                frame_id_list.append(frame_id)
                if offset + new_step < duration:
                    offset += new_step
        try:
            video_data = video_reader.get_batch(frame_id_list).asnumpy()
            sampled_list = [Image.fromarray(video_data[vid, :, :, :]).convert('RGB')
                            for vid, _ in enumerate(frame_id_list)]
        except Exception:
            raise RuntimeError('Error occurred in reading frames {} from video {} of duration {}.'.format(
                frame_id_list, directory, duration))
        return sampled_list


class ContextAndTargetVideoDataset(VideoMAE):
    """
    A video dataset in which each example consists of (1) a "context" sequence of length Tc
    and (2) a "target" sequence of length Tt.

    The two sequences share the same frame rate (specifiable in real units) but are
    separated by a specified gap (which may vary across examples).

    The main use case is training models to predict ahead by some variable amount,
    given the context.
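
    Each example is returned as a tensor of shape (C, Tc + Tt, H, W) when channels_first
    is True (otherwise (Tc + Tt, C, H, W)), optionally paired with a boolean mask produced
    by `mask_generator`.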
    """

    standard_fps = [12, 24, 30, 48, 60, 100]

    def __init__(self,
                 root,
                 setting,
                 train=True,
                 test_mode=False,
                 transform=None,
                 step_units='ms',
                 new_step=150,
                 start_frame=0,
                 context_length=2,
                 target_length=1,
                 channels_first=True,
                 generate_masks=True,
                 mask_generator=None,
                 context_target_gap=[400, 600],
                 normalize_timestamps=True,
                 default_fps=30,
                 min_fps=0.1,
                 seed=0,
                 *args,
                 **kwargs):
        super(ContextAndTargetVideoDataset, self).__init__(
            root=root,
            setting=setting,
            train=train,
            test_mode=test_mode,
            transform=transform,
            new_length=context_length,
            use_decord=True,
            lazy_init=False,
            video_loader=True,
            *args, **kwargs)

        self.context_length = self.new_length
        self.target_length = target_length

        self._fps = None
        self._min_fps = min_fps
        self._default_fps = default_fps
        self._step_units = step_units
        self.new_step = new_step

        self._start_frame = start_frame
        self.gap = context_target_gap
        self.seed = seed
        self.rng = np.random.RandomState(seed=seed)

        self._channels_first = channels_first
        self._normalize_timestamps = normalize_timestamps
        self._generate_masks = generate_masks
        self.mask_generator = mask_generator

    def _get_frames_per_t(self, t):
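        """Convert a duration `t`, expressed in `step_units`, into a whole number of frames.

        For example, at 25 fps a step of 150 ms maps to round(150 / 40) = 4 frames
        (with a minimum of 1 frame).
        """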
        if self._step_units == 'frames' or (self._step_units is None):
            return int(t)

        assert self._fps is not None
        t_per_frame = 1 / self._fps
        if self._step_units in ['ms', 'milliseconds']:
            t_per_frame *= 1000.0

        return max(int(np.round(t / t_per_frame)), 1)
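
    # `new_step` and `gap` are stored in `step_units` (e.g. milliseconds) and converted
    # to whole numbers of frames on access, once the clip's fps is known.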
    @property
    def new_step(self):
        if self._fps is None:
            return None
        else:
            return self._get_frames_per_t(self._new_step)

    @new_step.setter
    def new_step(self, v):
        self._new_step = v

    @property
    def gap(self):
        if self._fps is None:
            return [1, 2]
        else:
            gap = [self._get_frames_per_t(self._gap[0]),
                   self._get_frames_per_t(self._gap[1])]
            # Keep the sampling range non-degenerate.
            gap[1] = max(gap[1], gap[0] + 1)
            return gap

    @gap.setter
    def gap(self, v):
        if v is None:
            v = self._new_step
        if not isinstance(v, (list, tuple)):
            v = [v, v]
        self._gap = v

    def _get_video_name(self, directory):
        if ''.join(['.', self.video_ext]) in directory.split('/')[-1]:
            # The file name already contains the video extension.
            video_name = directory
        else:
            # Otherwise append the configured extension.
            video_name = '{}.{}'.format(directory, self.video_ext)
        return video_name

    def _set_fps(self, reader):
        """Snap the video fps to a standard value."""
        if self._step_units == 'frames' or self._step_units is None:
            self._fps = None
        else:
            self._fps = None
            fps = reader.get_avg_fps()
            for st in self.standard_fps:
                if (int(np.floor(fps)) == st) or (int(np.ceil(fps)) == st):
                    self._fps = st
            if self._fps is None:
                self._fps = int(np.round(fps))

            if self._fps < self._min_fps:
                self._fps = self._default_fps

    def _get_step_and_gap(self):
        step = self.new_step
        if self.randomize_interframes and self.train:
            step = self.rng.randint(1, step + 1)

        if self.train:
            gap = self.rng.randint(*self.gap)
        else:
            gap = sum(self.gap) // 2
        return (step, gap)

    def _sample_frames(self):
        step, gap = self._get_step_and_gap()
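
        # Number of source frames one example spans: the context clip, the gap between
        # context and target, and the target clip. For example, with context_length=2,
        # target_length=1, step=3 and gap=5, the sampled frames relative to the start
        # are [0, 3, 8].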
        self._total_length = self.context_length * step + gap + (self.target_length - 1) * step
        if self._total_length > (self._num_frames - self._start_frame):
            if self.train:
                return None
            else:
                raise ValueError(
                    "movie of length %d starting at fr=%d is too long for video of %d frames" %
                    (self._total_length, self._start_frame, self._num_frames))

        if self.train:
            self.start_frame_now = self.rng.randint(
                min(self._start_frame, self._num_frames - self._total_length),
                self._num_frames - self._total_length + 1)
        else:
            self.start_frame_now = min(self._start_frame, self._num_frames - self._total_length)

        frames = [self.start_frame_now + i * step for i in range(self.context_length)]
        frames += [frames[-1] + gap + i * step for i in range(self.target_length)]

        return frames

    def _decode_frame_images(self, reader, frames):
        try:
            video_data = reader.get_batch(frames).asnumpy()
            video_data = [Image.fromarray(video_data[t, :, :, :]).convert('RGB')
                          for t, _ in enumerate(frames)]
        except Exception:
            raise RuntimeError(
                "Error occurred in reading frames {} from video {} of duration {}".format(
                    frames, self.index, self._num_frames))
        return video_data

    def __getitem__(self, index):

        self.index = index
        self.directory, target = self.clips[index]

        self.video_name = self._get_video_name(self.directory)

        try:
            decord_vr = decord.VideoReader(self.video_name, num_threads=1)
            self._set_fps(decord_vr)
        except Exception:
            # Corrupt or unreadable video: fall back to the next sample.
            return self.__getitem__((index + 1) % len(self.clips))

        self._num_frames = len(decord_vr)
        self.frames = self._sample_frames()
        if self.frames is None:
            # The requested context + gap + target span does not fit in this video.
            print("no movie of length %d for video idx=%d" % (self._total_length, self.index))
            return self.__getitem__((index + 1) % len(self.clips))

        image_list = self._decode_frame_images(decord_vr, self.frames)

        if self.transform is None:
            image_tensor = torch.stack([transforms.ToTensor()(img) for img in image_list], 0)
        else:
            image_tensor = self.transform((image_list, None))

        image_tensor = image_tensor.view(self.context_length + self.target_length, 3, *image_tensor.shape[-2:])

        if self._channels_first:
            # (T, C, H, W) -> (C, T, H, W)
            image_tensor = image_tensor.transpose(0, 1)

        if self._generate_masks and self.mask_generator is not None:
            mask = self.mask_generator()
            return image_tensor, mask.bool()
        else:
            return image_tensor
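

# A minimal usage sketch (not part of the original module): the data root and the
# annotation file below are placeholders you would replace with your own. It loads
# a single context/target example and prints its shape.
if __name__ == "__main__":
    dataset = ContextAndTargetVideoDataset(
        root='data/videos',                      # placeholder path
        setting='data/train_annotations.txt',    # placeholder annotation file
        train=True,
        step_units='ms',
        new_step=150,                            # 150 ms between consecutive frames
        context_length=2,
        target_length=1,
        context_target_gap=[400, 600],           # gap sampled between 400 and 600 ms
        generate_masks=False)
    clip = dataset[0]
    print(clip.shape)                            # (3, context_length + target_length, H, W)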