|
import os |
|
import random |
|
import numpy as np |
|
import torch |
|
import pickle |
|
from PIL import Image |
|
import torch.utils.data as data |
|
import torch.nn.functional as F |
|
from torchvision.transforms import Compose, RandomApply, ToTensor, Normalize, CenterCrop, Lambda, RandomHorizontalFlip, ColorJitter, Resize, RandomCrop |
|
import json |
|
import av |
|
from torchvision.transforms.transforms import RandomResizedCrop |
|
from uniperceiver.tokenization import ClipTokenizer |
|
from uniperceiver.config import configurable |
|
from ..build import DATASETS_REGISTRY |
|
from uniperceiver.functional import dict_as_tensor |
|
from .video_transform import random_short_side_scale_jitter, uniform_crop |
|
import pyarrow as pa |
|
from uniperceiver.utils import comm |
|
import copy |
|
|
|
import io |
|
|
|
# Public API of this module: the dataset class and the clip-sampling helper.
__all__ = ["VideoDataSet", "random_clip"]
|
|
|
|
|
def load_pkl_file(filepath):
    """Load and return a pickled object from ``filepath``.

    Args:
        filepath (str): path to a pickle file; an empty string means
            "no file" and yields ``None`` (preserves the original contract).

    Returns:
        The unpickled object, or ``None`` when ``filepath`` is empty.
    """
    if len(filepath) == 0:
        return None
    # Use a context manager so the handle is closed even on error; the
    # previous implementation leaked the file object.
    with open(filepath, 'rb') as f:
        return pickle.load(f, encoding='bytes')
|
|
|
|
|
def random_clip(video_frames, sampling_rate, frames_per_clip, fixed_offset=False):
    """Sample ``frames_per_clip`` frame indices from a video.

    Args:
        video_frames (int): total frame number of a video
        sampling_rate (float): sampling rate for clip, pick one every k frames
        frames_per_clip (int): number of frames of a clip
        fixed_offset (bool): if True, choose the start offset deterministically
            (clip centred in the video) instead of uniformly at random.

    Returns:
        list[int]: frame indices (starting from zero). Indices wrap around via
        modulo when the clip would run past the end of the video.
    """
    new_sampling_rate = sampling_rate
    # Largest start offset such that the whole strided clip still fits.
    highest_idx = video_frames - int(new_sampling_rate * (frames_per_clip - 1) + 1)
    if highest_idx <= 0:
        # Video shorter than one clip: start at 0; modulo below wraps around.
        random_offset = 0
    else:
        if fixed_offset:
            # Centre the clip. Clamp at 0: for some (rate, length) pairs the
            # centred offset would be negative even though highest_idx > 0
            # (e.g. video_frames=6, rate=4, frames_per_clip=2), and the
            # modulo below would then silently reorder frames. [bug fix]
            random_offset = max(0, (video_frames - int(new_sampling_rate * frames_per_clip)) // 2)
        else:
            random_offset = int(np.random.randint(0, highest_idx, 1))
    frame_idx = [int(random_offset + int(i * sampling_rate)) % video_frames for i in range(frames_per_clip)]
    # Defensive filter; modulo already guarantees x < video_frames.
    frame_idx = [x for x in frame_idx if x < video_frames]
    return frame_idx
|
|
|
@DATASETS_REGISTRY.register()
class VideoDataSet(data.Dataset):
    """Action-recognition video dataset.

    Reads an ActivityNet-style JSON annotation file (a ``'database'`` mapping
    of video name -> subset/label info), decodes video clips with PyAV and
    returns samples in the uniperceiver ``input_sample``/``target_sample``
    dict format.  Videos are read either from local disk or from ceph/s3
    storage (through ``TCSLoader``) depending on ``use_ceph``.
    """

    @configurable
    def __init__(self, cfg, stage, root_path, s3_path, list_file, category_file, use_ceph, tcs_conf_path,
                 tokenizer, tokenizer_name, data_percentage,
                 frames_per_clip=64, interval=4, num_clips=1,
                 is_train=True, test_mode=False, num_classes=None, target_fps=30, timesformer_aug=False, minibatches=1):
        """
        Args:
            cfg: global config node (accessed for DATASETS/DATALOADER/MODEL keys).
            stage (str): 'train' or an evaluation stage name.
            root_path (str): the file path to the root of the local video folder.
            s3_path (str): root of the video folder on ceph/s3 storage.
            list_file (str): annotation JSON file with a 'database' mapping.
            category_file (str): tab-separated "class_name<TAB>index" file.
            use_ceph (bool): read video bytes via TCSLoader instead of disk.
            tcs_conf_path (str): petrel/ceph client config path.
            tokenizer / tokenizer_name: text tokenizer (stored, not used here).
            data_percentage (float): fraction of training videos kept per class.
            frames_per_clip (int): number of frames per data sample.
            interval (int): interval (stride) between sampled frames.
            num_clips (int): number of clips per video at evaluation time.
            is_train (bool): training mode (random clip sampling).
            test_mode (bool): testing mode, no label-free path but enables
                multi-view frame/interval rescaling below.
            num_classes (int): number of target classes (stored only).
            target_fps (int): reference fps the stride is normalized to.
            timesformer_aug (bool): use the timesformer-style augmentation.
            minibatches (int): stored for the training loop (unused here).
        """

        self.cfg = cfg
        self.stage = stage
        self.root_path = root_path
        self.s3_path = s3_path
        self.list_file = list_file
        self.category_file = category_file
        self.frames_per_clip = frames_per_clip
        self.interval = interval
        self.num_clips = num_clips
        self.is_train = is_train
        self.test_mode = test_mode
        self.num_classes = num_classes
        self.target_fps = target_fps
        self.minibatches = minibatches
        self.data_percentage = data_percentage

        self.tokenizer = tokenizer
        self.tokenizer_name = tokenizer_name

        # NOTE: _transform() starts with `assert False` — only the
        # timesformer augmentation path is actually usable.
        self.transform = self._timesformer_transform() if timesformer_aug else self._transform()

        self.use_ceph = use_ceph
        if self.use_ceph:
            # Videos live on ceph/s3; bytes are fetched through TCSLoader.
            self.data_path = self.s3_path
            print('debug info for {} {} '.format(self.cfg.DATASETS.DATASET_NAME, self.data_path))
            # Imported lazily so local runs do not need the ceph client stack.
            from uniperceiver.datasets import TCSLoader

            self.tcs_loader = TCSLoader(tcs_conf_path)
        else:
            self.data_path = self.root_path

        # Store the sample list as a pyarrow array: keeps per-worker memory
        # low when DataLoader workers fork this dataset.
        _temp_list = self.load_data(self.cfg)
        self.video_list = pa.array(_temp_list)
        if comm.is_main_process():
            import sys
            print(f"!!! Dataset {self.cfg.DATASETS.DATASET_NAME} with task {self.cfg.DATASETS.TASK_TYPE}:")
            print('!!! length of _temp_list: ', len(_temp_list))
            # sys.getsizeof is shallow — these numbers are only indicative.
            print('!!! size of _temp_list: ', sys.getsizeof(_temp_list))
            print('!!! size of pa database: ', sys.getsizeof(self.video_list))
        del _temp_list

        # NOTE: the config keys below ('MULTI_VEIW', 'MULTI_VEIW_NUM',
        # 'RANDON_STRIDE') are spelled exactly as they appear in the project
        # configs — do not "fix" them here without changing the configs too.
        self.testing_multi_view = self.cfg.DATALOADER.get('MULTI_VEIW', 'v0')
        self.temporal_num_view = self.cfg.DATALOADER.get('MULTI_VEIW_NUM', 1)

        self.random_stride = self.cfg.DATALOADER.get('RANDON_STRIDE', False)

        if self.test_mode:
            # Multi-view testing: sample one longer, denser clip and split it
            # into `temporal_num_view` views later in __getitem__.
            self.frames_per_clip = int(self.frames_per_clip*self.temporal_num_view)
            self.interval = int(self.interval/self.temporal_num_view)

        # Metadata attached (deep-copied) to every returned sample.
        self.task_info = {
            'task_type'      : self.cfg.DATASETS.TASK_TYPE,
            'dataset_name'   : self.cfg.DATASETS.DATASET_NAME,
            'batch_size'     : self.cfg.DATALOADER.TRAIN_BATCH_SIZE if self.stage == 'train' else self.cfg.DATALOADER.TEST_BATCH_SIZE,
            'sampling_weight': self.cfg.DATALOADER.SAMPLING_WEIGHT,
        }

        self.target_set = self.cfg.DATASETS.TARGET_SET

    def _transform(self):
        # Legacy augmentation pipeline — intentionally disabled in favour of
        # the timesformer-style pipeline; kept for reference only.
        assert False, 'use timesformer augmentation'
        transforms = [
            # PIL frames -> (T, C, H, W) float tensor.
            Lambda(lambda frames: torch.stack([ToTensor()(frame.convert("RGB")) for frame in frames])),
        ]
        if self.test_mode:
            transforms.extend([
                # Deterministic-ish centre-ish crop at test time.
                RandomResizedCrop(224, scale=(0.75, 0.75), ratio=(1.0, 1.0)),
            ])
        else:
            transforms.extend([
                RandomResizedCrop(224, scale=(0.56, 0.95), ratio=(1.0, 1.0)),
                RandomHorizontalFlip(),
                RandomApply(torch.nn.ModuleList([ColorJitter(0.4, 0.4, 0.4)]), 0.8),
            ])
        transforms.append(
            # ImageNet mean/std normalization.
            Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        )
        return Compose(transforms)

    def _timesformer_transform(self):
        """TimeSformer-style clip augmentation built from the cfg input size."""
        transforms = [
            # PIL frames -> (T, C, H, W) float tensor.
            Lambda(lambda frames: torch.stack([ToTensor()(frame.convert("RGB")) for frame in frames])),
        ]
        if self.test_mode:
            test_scale = self.cfg.MODEL.IMG_INPUT_SIZE
            transforms.extend([
                # Rescale the short side to exactly test_scale...
                Lambda(lambda frames: random_short_side_scale_jitter(frames, test_scale, test_scale)[0]),
                # ...then take the standard 3 spatial crops (left/centre/right
                # or top/centre/bottom) and stack them as views.
                Lambda(lambda images: torch.stack([uniform_crop(images, test_scale, i)[0] for i in range(3)], 0))
            ])
        else:
            # Train-time short-side jitter range, scaled from the canonical
            # 256-320 @ 224 recipe to the configured input size.
            min_scale = int((256 / 224)*self.cfg.MODEL.IMG_INPUT_SIZE)
            max_scale = int((320 / 224)*self.cfg.MODEL.IMG_INPUT_SIZE)
            transforms.extend([
                Lambda(lambda frames: random_short_side_scale_jitter(frames, min_scale, max_scale)[0].unsqueeze(0)),
                RandomHorizontalFlip(),
                RandomCrop(self.cfg.MODEL.IMG_INPUT_SIZE)
            ])
        transforms.append(
            # Normalize to [-1, 1].
            Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        )
        return Compose(transforms)

    @classmethod
    def from_config(cls, cfg, stage: str = "train"):
        """Build the kwargs dict consumed by @configurable __init__."""
        if 'SLURM_PROCID' in os.environ:
            # Running under SLURM: use the configured petrel/ceph config.
            tcs_conf_path = cfg.DATALOADER.get("TCS_CONF_PATH", "petreloss.config")
        else:
            # Local debugging fallback config.
            tcs_conf_path = "slurm_tools/petreloss_local.config"
        ret = {
            "cfg": cfg,
            "stage": stage,
            "list_file": os.path.join(cfg.DATALOADER.ANNO_FOLDER, cfg.DATALOADER.ANNO_FILE),
            "category_file": os.path.join(cfg.DATALOADER.ANNO_FOLDER, "category_mapping.txt"),
            # Kinetics-style folder layout: <root>/training or <root>/validation.
            "root_path": os.path.join(cfg.DATALOADER.FEATS_FOLDER, "training" if stage == "train" else "validation"),
            "s3_path": os.path.join(cfg.DATALOADER.S3_PATH, "training" if stage == "train" else "validation"),
            "frames_per_clip": cfg.DATALOADER.FRAMES_PER_CLIP,
            "interval": cfg.DATALOADER.STRIDE,
            "num_clips": 1 if stage == 'train' else cfg.INFERENCE.NUM_VIEWS,
            "is_train": stage == 'train',
            "test_mode": stage != 'train',
            "num_classes": cfg.MODEL.NUM_CLASSES,
            "timesformer_aug": cfg.DATALOADER.TIMESFORMER_AUG,
            "minibatches": cfg.DATALOADER.MINI_BATCHES,
            "use_ceph": getattr(cfg.DATALOADER, 'USE_CEPH', False),
            "tcs_conf_path": tcs_conf_path,
            "data_percentage": cfg.DATALOADER.DATA_PERCENTAGE,
        }

        ret['tokenizer'] = ClipTokenizer()
        ret['tokenizer_name'] = "clip"

        return ret

    def load_data(self, cfg):
        """Parse the category file and annotation JSON.

        Side effects: populates ``self.cls2idx`` / ``self.idx2cls``.

        Returns:
            list[dict]: one dict per video with keys 'video_path', 'id'
            and 'target_label'.
        """
        # Build class-name <-> index mappings from the tab-separated file.
        self.cls2idx = dict()
        self.idx2cls = dict()
        self.class_names = list()
        with open(self.category_file, 'r') as f:
            for line in f.readlines():
                class_name, idx = line.strip().split('\t')
                # Spaces in class names become underscores (matches folder names).
                class_name = class_name.replace(" ", "_")
                self.cls2idx[class_name] = int(idx)
                self.idx2cls[int(idx)] = class_name

        # Optional per-dataset list of corrupted/missing video basenames.
        exclude_list = list()
        if os.path.exists(os.path.join(os.path.dirname(self.list_file), "exclude_list.txt")):
            with open(os.path.join(os.path.dirname(self.list_file), "exclude_list.txt"), 'r') as f:
                exclude_list = list(f)
                exclude_list = [t.strip() for t in exclude_list]

        video_list = []
        count = 0  # NOTE(review): unused — candidate for removal.
        with open(self.list_file) as f:
            data_file = json.load(f)
            for name, info in data_file['database'].items():
                # Layout: <data_path>/<label>/<name><ext>
                video_path = os.path.join(self.data_path, info["annotations"]['label'], name+cfg.DATALOADER.FILE_EXTENSION)

                if os.path.basename(video_path) in exclude_list:
                    continue
                # Keep only the subset matching the current mode.
                if (self.is_train and info['subset'] == "training") or (not self.is_train and info['subset'] == "validation") :
                    inst = {
                        "video_path" : video_path,
                        "id": name
                    }

                    label = info['annotations']['label']
                    inst["target_label"] = label
                    assert label in self.cls2idx
                    video_list.append(inst)

        # Optionally keep only a class-balanced fraction of the training set.
        if self.is_train and self.data_percentage < 1.0:
            video_dict = dict()
            for video in video_list:
                if video["target_label"] not in video_dict:
                    video_dict[video["target_label"]] = list()
                video_dict[video["target_label"]].append(video)
            new_list = list()
            for k, v in video_dict.items():
                # +1 guarantees at least one sample per class.
                new_list.extend(random.sample(v, k=int(len(v)*self.data_percentage)+1))
            video_list = new_list

        num = len(video_list)
        print("The number of videos is {}".format(num), flush=True)
        assert (num > 0)
        return video_list

    def _sample_indices(self, total_frames, fps):
        """
        Used for training: sample one random clip.

        Args:
            total_frames (int): number of frames in the video.
            fps (float): the video's average frame rate.

        Returns:
            np.ndarray: zero-based frame indices (random_clip returns
            indices starting from zero).
        """
        if self.random_stride:
            # Stride augmentation: pick a random temporal stride.
            interval = random.sample([8, 16, 32], k=1)[0]
        else:
            interval = self.interval
        # Rescale the stride by fps/target_fps so sampling is uniform in
        # wall-clock time regardless of the source frame rate.
        frame_idx = np.asarray(random_clip(total_frames, interval * fps / self.target_fps , self.frames_per_clip))
        return frame_idx

    def _get_val_indices(self, total_frames, fps):
        """Deterministic indices for `num_clips` evaluation clips whose start
        positions are spread evenly over the video."""
        max_frame_idx = total_frames - 1
        # Last start position at which a full clip still fits.
        sample_pos = max(0, 1 + max_frame_idx - int(self.interval * fps / self.target_fps * self.frames_per_clip))
        start_list = np.linspace(0, sample_pos - 1, num=self.num_clips, dtype=int)
        frame_idx = []
        for start_idx in start_list.tolist():
            # Modulo wraps around if the clip runs past the video end.
            ids = [int(idx * self.interval * fps / self.target_fps + start_idx)%total_frames for idx in range(self.frames_per_clip)]
            # Defensive filter; modulo already guarantees x < total_frames.
            ids = [x for x in ids if x < total_frames]
            frame_idx.append(ids)
        return frame_idx

    def __getitem__(self, index):
        # Retry loop: on any decode/IO failure, resample a random video.
        # If all 100 attempts fail this falls through and returns None.
        for i_try in range(100):
            try:
                # Materialize the pyarrow record back into a Python dict.
                record = self.video_list[index].as_py()
                if self.use_ceph:
                    container = av.open(io.BytesIO(self.tcs_loader.client.get(record["video_path"])))
                else:
                    container = av.open(record["video_path"])

                stream = container.streams.video[0]
                total_frames = stream.frames
                fps = float(container.streams.video[0].average_rate)

                if total_frames == 0:
                    # Some containers (e.g. webm) do not expose a frame
                    # count; decode the whole stream once just to count.
                    for frame in container.decode(stream):
                        total_frames += 1
                    container.close()
                    # NOTE(review): this reopen always uses the raw path and
                    # bypasses tcs_loader — looks broken when use_ceph is
                    # True; confirm against the ceph deployment.
                    container = av.open(record["video_path"])
                    stream = container.streams.video[0]
            except Exception as e:
                print(
                    "Failed to load video from {} with error {} ; trial {}".format(
                        record["video_path"], e, i_try
                    )
                )

                # Retry with a different, randomly chosen sample.
                index = random.randint(0, len(self.video_list) - 1)
                continue

            if self.is_train:
                indices = [self._sample_indices(total_frames, fps)]
            else:
                indices = self._get_val_indices(total_frames, fps)

            # Union of all frame indices needed across clips/views.
            # (NOTE: this reuses the name `index`, shadowing the sample index —
            # harmless because the sample index is not needed past this point.)
            all_index = set()
            for index in indices:
                all_index = all_index.union(set(index))

            start_index = min(all_index)  # NOTE(review): appears unused.
            num_frames = len(all_index)

            images = dict()

            fetched = 0

            # Single sequential decode pass, keeping only the wanted frames
            # and stopping as soon as all of them have been collected.
            for frame in container.decode(stream):
                if frame.index not in all_index or frame.index in images:
                    continue
                images[frame.index] = frame.to_rgb().to_image()
                last = frame.index
                fetched += 1
                if fetched == num_frames:
                    break

            container.close()

            video_data = list()
            for ind in indices:
                seq = list()
                for i in ind:
                    if i in images:
                        seq.append(images[i])
                    else:
                        # Fallback: reuse the last decoded frame for indices
                        # the decoder never produced.
                        # NOTE(review): `last` is unbound if no frame at all
                        # was decoded — would raise here; confirm acceptable.
                        seq.append(images[last])
                video_data.append(self.transform(seq))
            # Concatenate clips/views along the first (view) dimension.
            video_data = torch.cat(video_data, dim=0)

            if not self.is_train:
                if self.testing_multi_view == 'v1' and self.temporal_num_view > 1:
                    # v1: split the long clip into contiguous temporal views.
                    video_data = video_data.reshape(video_data.shape[0] * self.temporal_num_view, -1, *video_data.shape[-3:])
                    num_frames = num_frames // self.temporal_num_view
                elif self.testing_multi_view == 'v2' and self.temporal_num_view > 1:
                    # v2: interleaved temporal views (stride-sampled in time).
                    video_data = video_data.reshape(video_data.shape[0], -1, self.temporal_num_view,
                                                    *video_data.shape[-3:]).transpose(1, 2).reshape(video_data.shape[0] * self.temporal_num_view, -1,
                                                                                                    *video_data.shape[-3:])
                    num_frames = num_frames // self.temporal_num_view

            # Uniperceiver sample format: one video input, class-index target.
            ret = {
                'input_sample':[
                    {
                        'data': video_data, 'invalid_mask': None, 'modality': 'video', 'data_type': 'input',
                        'sample_info':{
                            'id': record['id'],
                            'path': record['video_path'],
                            'num_frames': num_frames,
                            'num_views': video_data.shape[0],
                            'cat_along_first_dim': True,
                        }
                    }
                ],
                'target_sample': [],
                'target_idx': [self.cls2idx[record['target_label']]],
                'target_set': copy.deepcopy(self.target_set),
                'task_info': copy.deepcopy(self.task_info)
            }

            return ret

    def __len__(self):
        # Number of videos in the pyarrow-backed sample list.
        return len(self.video_list)
|
|