"""
@author: Yehao Li, Jingwen Chen
@contact: [email protected], [email protected]
"""
|
import os |
|
import copy |
|
import pickle |
|
import random |
|
import numpy as np |
|
import torch |
|
from uniperceiver.config import configurable |
|
from uniperceiver.functional import read_np, dict_as_tensor |
|
from ..build import DATASETS_REGISTRY |
|
from uniperceiver.tokenization import ClipTokenizer |
|
from torchvision.transforms import Compose, RandomApply, ToTensor, Normalize, CenterCrop, Lambda, RandomHorizontalFlip, ColorJitter, Resize, RandomCrop |
|
from .video_transform import random_short_side_scale_jitter, uniform_crop |
|
import json |
|
from io import BytesIO |
|
import av |
|
from .video_raw import VideoDataSet |
|
import io |
|
from collections import defaultdict |
|
|
|
import pyarrow as pa |
|
from uniperceiver.utils import comm |
|
import copy |
|
|
|
# Public API of this module: only the dataset class is exported by a star-import.
__all__ = ["MSRVTTDataset"]
|
|
|
def random_clip(video_frames, sampling_rate, frames_per_clip, fixed_offset=False):
    """Pick the frame indices of one clip from a video.

    The clip covers ``int(sampling_rate * frames_per_clip)`` source frames.
    When the video is longer than that span, the clip start is either centred
    (``fixed_offset=True``) or drawn uniformly at random from
    ``[0, video_frames - span)`` using NumPy's global RNG. When the video is
    not longer than the span, the clip starts at frame 0 and indices that run
    past the end wrap around via modulo.

    Args:
        video_frames (int): total frame number of a video
        sampling_rate (int): sampling rate for clip, pick one every k frames
        frames_per_clip (int): number of frames of a clip
        fixed_offset (bool): use the centred offset instead of a random one,
            making the result deterministic.

    Returns:
        list[int]: frame indices (started from zero); empty when the video
        has no frames.
    """
    # An empty video has nothing to sample; without this guard the modulo
    # below would raise ZeroDivisionError.
    if video_frames <= 0:
        return []

    clip_span = int(sampling_rate * frames_per_clip)
    highest_idx = video_frames - clip_span

    if highest_idx <= 0:
        offset = 0
    elif fixed_offset:
        offset = highest_idx // 2
    else:
        # Draw a scalar directly: int() on a 1-element array is deprecated
        # in recent NumPy releases.
        offset = int(np.random.randint(0, highest_idx))

    # The modulo keeps every index inside [0, video_frames), so no further
    # range filtering is required.
    return [
        int(offset + int(i * sampling_rate)) % video_frames
        for i in range(frames_per_clip)
    ]
|
|
|
|
|
@DATASETS_REGISTRY.register()
class MSRVTTDataset(VideoDataSet):
    """MSR-VTT raw-video dataset producing (video clip, caption) samples.

    Annotations are COCO-style JSON files (``"images"`` and ``"annotations"``
    lists). Parsed (video_id, caption) records are cached as pickles next to
    the annotation files and kept in a pyarrow array. ``__getitem__`` decodes
    the sampled frames with PyAV, applies ``self.transform`` and attaches the
    tokenized caption for the ``video_caption`` / ``video_retrieval`` tasks.

    NOTE(review): ``_transform``, ``_sample_indices`` and ``_get_val_indices``
    are not defined in this class; presumably they come from ``VideoDataSet``.
    Helper methods added here use double-underscore names so that they cannot
    accidentally override anything in the base class.
    """

    @configurable
    def __init__(
        self,
        stage: str,
        anno_file: str,
        seq_per_img: int,
        max_feat_num: int,
        max_seq_len: int,
        feats_folder: str,
        tokenizer,
        tokenizer_name,
        use_ceph: bool,
        tcs_conf_path,
        frames_per_clip, interval, num_clips, timesformer_aug,
        task_type,
        data_percentage,
        target_fps=30,
        random_mask=False,
        cfg=None,
    ):
        """Store the configuration, index the videos and load the records.

        Args:
            stage: ``'train'`` selects training behaviour; anything else is
                treated as test mode.
            anno_file: annotation JSON path, or a list of such paths.
            seq_per_img: stored on the instance (not used in this class).
            max_feat_num: number of rows kept by ``_sample_frame``.
            max_seq_len: maximum number of caption tokens; also part of the
                cache file name.
            feats_folder: folder containing the ``.mp4`` files (replaced by a
                fixed ``s3://`` prefix when ``use_ceph`` is true).
            tokenizer: tokenizer object used to encode captions.
            tokenizer_name: ``'clip'`` enables the CLIP code paths; also part
                of the cache file name.
            use_ceph: read annotations and videos through ``TCSLoader``.
            tcs_conf_path: ``TCSLoader`` config used when ``SLURM_PROCID`` is
                set in the environment.
            frames_per_clip, interval, num_clips: stored on the instance.
            timesformer_aug: use ``_timesformer_transform()`` instead of
                ``_transform()``.
            task_type: ``'video_caption'`` or ``'video_retrieval'``.
            data_percentage: fraction of training records to keep (< 1.0
                triggers random sub-sampling).
            target_fps: stored on the instance.
            random_mask: stored on the instance.
            cfg: full config node; required in practice because several
                fields are read from it directly.
        """
        self.cfg = cfg
        self.stage = stage
        self.anno_file = anno_file
        self.seq_per_img = seq_per_img
        self.max_feat_num = max_feat_num
        self.feats_folder = feats_folder
        self.max_seq_len = max_seq_len
        self.task_type = task_type

        self.initialized = False

        self.tokenizer = tokenizer
        self.tokenizer_name = tokenizer_name
        self.use_clip_tokenizer = self.tokenizer_name == 'clip'

        self.idx2name = dict()
        self.name2idx = dict()

        self.use_ceph = use_ceph
        # The cache directory is derived from the annotation path as given,
        # i.e. before it is possibly rewritten to an s3:// location below.
        first_anno = self.anno_file[0] if isinstance(self.anno_file, list) else self.anno_file
        self.cache_dir = os.path.join(os.path.dirname(first_anno), 'cache')
        self.frames_per_clip = frames_per_clip
        self.interval = interval

        # Config key spelling ('RANDON_STRIDE') is kept exactly as the code reads it.
        self.random_stride = self.cfg.DATALOADER.get('RANDON_STRIDE', False)

        self.num_clips = num_clips
        self.is_train = stage == 'train'
        self.test_mode = stage != 'train'
        self.transform = self._timesformer_transform() if timesformer_aug else self._transform()
        self.target_fps = target_fps
        self.data_percentage = data_percentage

        if self.use_ceph:
            self.feats_folder = 's3://msrvtt/videos/'
            if isinstance(self.anno_file, list):
                self.anno_file = [
                    os.path.join('s3://msrvtt/annotations/', os.path.basename(path))
                    for path in self.anno_file
                ]
            else:
                self.anno_file = os.path.join('s3://msrvtt/annotations/', os.path.basename(self.anno_file))
            print('debug info for msrvtt pretrain: {} '.format(self.feats_folder))
            from uniperceiver.datasets import TCSLoader
            if 'SLURM_PROCID' in os.environ:
                self.tcs_loader = TCSLoader(tcs_conf_path)
            else:
                self.tcs_loader = TCSLoader('slurm_tools/petreloss_local.config')

        # Build the id <-> file-name lookup from every annotation file.
        anno_files = self.anno_file if isinstance(self.anno_file, list) else [self.anno_file]
        for path in anno_files:
            for vinfo in self.__read_json(path)["images"]:
                self.idx2name[vinfo['id']] = vinfo['file_name']
                self.name2idx[vinfo['file_name']] = vinfo['id']
        self.random_mask = random_mask

        # Keep the records in a pyarrow array rather than a Python list.
        _temp_list = self.load_data(self.cfg)
        self.video_list = pa.array(_temp_list)
        if comm.is_main_process():
            import sys
            print(f"!!! Dataset {self.cfg.DATASETS.DATASET_NAME} with task {self.cfg.DATASETS.TASK_TYPE}:")
            print('!!! length of _temp_list: ', len(_temp_list))
            print('!!! size of _temp_list: ', sys.getsizeof(_temp_list))
            print('!!! size of pa database: ', sys.getsizeof(self.video_list))
        del _temp_list

        if self.stage == 'train':
            batch_size = self.cfg.DATALOADER.TRAIN_BATCH_SIZE
        else:
            batch_size = self.cfg.DATALOADER.TEST_BATCH_SIZE
        self.task_info = {
            'task_type'      : self.cfg.DATASETS.TASK_TYPE,
            'dataset_name'   : self.cfg.DATASETS.DATASET_NAME,
            'batch_size'     : batch_size,
            'sampling_weight': self.cfg.DATALOADER.SAMPLING_WEIGHT
        }

        self.target_set = self.cfg.DATASETS.TARGET_SET

    @classmethod
    def from_config(cls, cfg, stage: str = "train"):
        """Translate ``cfg`` into the keyword arguments of ``__init__``.

        Args:
            cfg: config node providing the DATALOADER / MODEL / DATASETS /
                INFERENCE entries read below.
            stage: ``"train"`` or ``"test"``.

        Returns:
            dict: keyword arguments for the constructor.

        Raises:
            NotImplementedError: if ``cfg.INFERENCE.VOCAB`` is not ``'CLIP'``.
        """
        if stage == "train":
            ann_file = os.path.join(cfg.DATALOADER.ANNO_FOLDER, "caption_msrvtt_1k_trainval_cocostyle.json")
        else:
            assert stage == "test"
            ann_file = os.path.join(cfg.DATALOADER.ANNO_FOLDER, "caption_msrvtt_1k_test_cocostyle.json")
        feat_path = os.path.join(cfg.DATALOADER.FEATS_FOLDER, "MSRVTT_ResNet152_{}.hdf5".format(stage))

        if 'SLURM_PROCID' in os.environ:
            tcs_conf_path = cfg.DATALOADER.get("TCS_CONF_PATH", "petreloss.config")
        else:
            tcs_conf_path = "petreloss_local.config"

        ret = {
            "stage": stage,
            "anno_file": ann_file,
            "seq_per_img": cfg.DATALOADER.SEQ_PER_SAMPLE,
            "max_feat_num": cfg.DATALOADER.MAX_FEAT_NUM,
            "feats_folder": feat_path,
            "max_seq_len": cfg.MODEL.MAX_SEQ_LEN,
            "use_ceph": getattr(cfg.DATALOADER, 'USE_CEPH', False),
            "tcs_conf_path": tcs_conf_path,
            'task_type': cfg.DATASETS.TASK_TYPE,
            "frames_per_clip": cfg.DATALOADER.FRAMES_PER_CLIP,
            "interval": cfg.DATALOADER.STRIDE,
            "num_clips": 1 if stage == 'train' else cfg.INFERENCE.NUM_VIEWS,
            "timesformer_aug": cfg.DATALOADER.TIMESFORMER_AUG,
            "data_percentage": cfg.DATALOADER.DATA_PERCENTAGE,
            "cfg": cfg,
        }
        if getattr(cfg.INFERENCE, "VOCAB", None) == 'CLIP':
            ret['tokenizer'] = ClipTokenizer()
            ret['tokenizer_name'] = "clip"
        else:
            raise NotImplementedError
        return ret

    def __read_json(self, path):
        """Load a JSON document from ceph (``use_ceph``) or the local disk."""
        if self.use_ceph:
            return json.load(BytesIO(self.tcs_loader.client.get(path)))
        with open(path) as f:
            return json.load(f)

    def __load_split(self, cfg, split, anno_file):
        """Return the cached records of ``split``, building the cache if needed.

        Args:
            cfg: forwarded to ``load_raw_data``.
            split: split name used in the cache file name.
            anno_file: annotation file to parse on a cache miss, or ``None``
                when no annotation file is available for this split.

        Returns:
            list | None: the records, or ``None`` when there is neither a
            cache file nor an annotation file for the split.
        """
        cache_path = os.path.join(
            self.cache_dir,
            "msrvtt_raw_caption_retrieval_%s_%s_%d.pkl" % (self.tokenizer_name, split, self.max_seq_len)
        )
        # exist_ok avoids a check-then-create race when several ranks start together.
        os.makedirs(os.path.dirname(cache_path), exist_ok=True)
        if not os.path.exists(cache_path):
            if anno_file is None:
                return None
            datalist = self.load_raw_data(cfg, anno_file)
            with open(cache_path, "wb") as f:
                pickle.dump(datalist, f)
        # NOTE: pickle executes arbitrary code on load; the cache directory
        # must only contain files written by this class.
        with open(cache_path, "rb") as f:
            return pickle.load(f)

    def load_data(self, cfg):
        """Load all (video_id, caption) records for the current stage.

        For training the ``train`` and ``val`` splits are concatenated, records
        holding a list of captions are expanded to one record per caption, and
        the result is randomly sub-sampled when ``data_percentage < 1.0``.
        For testing the records are returned as cached.

        ``self.anno_file`` may be a list (one file per split, in
        ``train``/``val`` order) or a single path. A single path is treated as
        a one-element list: it feeds the first split, and a split that has
        neither a cache file nor an annotation file is skipped.

        Returns:
            list[dict]: records with ``video_id`` and ``caption`` keys.
        """
        if self.stage == "train":
            anno_files = self.anno_file if isinstance(self.anno_file, list) else [self.anno_file]
            total_datalist = list()
            for i, split in enumerate(["train", "val"]):
                anno_file = anno_files[i] if i < len(anno_files) else None
                datalist = self.__load_split(cfg, split, anno_file)
                if datalist is None:
                    continue
                # Caches written in the grouped (one record per video) format
                # are flattened to one record per caption.
                if datalist and isinstance(datalist[0]['caption'], list):
                    flattened = list()
                    for data in datalist:
                        if isinstance(data['caption'], str):
                            flattened.append(data)
                        else:
                            video_id = data['video_id']
                            flattened.extend(
                                {"video_id": video_id, "caption": caption}
                                for caption in data['caption']
                            )
                    datalist = flattened
                total_datalist.extend(datalist)

            if self.data_percentage < 1.0:
                keep = int(self.data_percentage * len(total_datalist))
                total_datalist = random.sample(total_datalist, k=keep)
        else:
            assert self.stage == "test"
            total_datalist = self.__load_split(cfg, self.stage, self.anno_file)
        return total_datalist

    def load_raw_data(self, cfg, anno_file):
        """Parse one annotation file into a list of records.

        Annotations are processed in ascending ``id`` order. In the train
        stage every annotation becomes its own record; otherwise all captions
        of a video are grouped into one record whose ``caption`` is a list.

        Args:
            cfg: unused; kept for interface compatibility.
            anno_file: path of the annotation JSON to parse.

        Returns:
            list[dict]: records with ``video_id`` and ``caption`` keys.
        """
        annoinfo = self.__read_json(anno_file)
        annotations = sorted(annoinfo['annotations'], key=lambda x: x['id'])

        if self.stage == 'train':
            return [
                {'video_id': data['image_id'], 'caption': data['caption']}
                for data in annotations
            ]

        video2caps = defaultdict(list)
        for data in annotations:
            video2caps[data['image_id']].append(data['caption'])
        return [
            {'video_id': videoid, 'caption': caps}
            for videoid, caps in video2caps.items()
        ]

    def _timesformer_transform(self):
        """Build the frame-list -> tensor pipeline used with ``timesformer_aug``.

        Frames are converted to RGB tensors and stacked. In test mode they are
        rescaled so the short side equals ``cfg.MODEL.IMG_INPUT_SIZE`` and
        centre-cropped. In training the short side is jittered between 256/224
        and 320/224 of the input size, a leading dimension is added, and a
        random horizontal flip and random crop follow. Both modes end with a
        (0.5, 0.5) mean/std normalisation.

        Returns:
            torchvision.transforms.Compose: the composed transform.
        """
        input_size = self.cfg.MODEL.IMG_INPUT_SIZE
        transforms = [
            Lambda(lambda frames: torch.stack([ToTensor()(frame.convert("RGB")) for frame in frames])),
        ]
        if self.test_mode:
            test_scale = input_size
            transforms.extend([
                Lambda(lambda frames: random_short_side_scale_jitter(
                    frames, test_scale, test_scale)[0]),
                CenterCrop(test_scale),
            ])
        else:
            min_scale = int((256 / 224) * input_size)
            max_scale = int((320 / 224) * input_size)
            transforms.extend([
                Lambda(lambda frames: random_short_side_scale_jitter(frames, min_scale, max_scale)[0].unsqueeze(0)),
                RandomHorizontalFlip(),
                RandomCrop(input_size)
            ])
        transforms.append(
            Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        )
        return Compose(transforms)

    def _sample_frame(self, atten_feats):
        """Pick ``max_feat_num`` evenly spaced rows from ``atten_feats``.

        Args:
            atten_feats: 2-D array indexed as ``[row, :]``.

        Returns:
            The selected rows, in their original order.
        """
        interval = atten_feats.shape[0] / self.max_feat_num
        selected_indexes = [int(i * interval) for i in range(self.max_feat_num)]
        return atten_feats[selected_indexes, :]

    def random_word_wwm(self, tokens):
        """Apply whole-word masking to a sequence of basic tokens.

        Every word is selected with probability 0.15. All sub-tokens of a
        selected word are replaced by the mask token (80%), replaced by random
        vocabulary entries (10%) or left unchanged (10%).

        Args:
            tokens: iterable of basic (word-level) tokens.

        Returns:
            tuple[list, list]: ``(output_tokens, output_label)``, aligned per
            sub-token. Labels hold the original sub-token (vocabulary id for
            the non-CLIP tokenizer, ``[UNK]`` id when unknown) for selected
            words and ``-1`` elsewhere.
        """
        output_tokens = []
        output_label = []
        vocab_keys = None  # built lazily, only for the non-CLIP random branch

        for token in tokens:
            if self.use_clip_tokenizer:
                sub_tokens = self.tokenizer.encode_basic_tokenized_token(token)
            else:
                sub_tokens = self.tokenizer.wordpiece_tokenizer.tokenize(token)
            prob = random.random()

            if prob >= 0.15:
                # Word not selected: keep it and mark it as "no prediction".
                for sub_token in sub_tokens:
                    output_tokens.append(sub_token)
                    output_label.append(-1)
                continue

            # Rescale so the 80/10/10 split can be read off directly.
            prob /= 0.15

            if prob < 0.8:
                for _ in sub_tokens:
                    if self.use_clip_tokenizer:
                        output_tokens.append(self.tokenizer.encoder["<|spe|>"])
                    else:
                        output_tokens.append("[MASK]")
            elif prob < 0.9:
                for _ in sub_tokens:
                    if self.use_clip_tokenizer:
                        # random.choice on a range draws exactly like on the
                        # materialised list, without building it each time.
                        output_tokens.append(random.choice(range(len(self.tokenizer.encoder))))
                    else:
                        if vocab_keys is None:
                            vocab_keys = list(self.tokenizer.vocab.keys())
                        output_tokens.append(random.choice(vocab_keys))
            else:
                output_tokens.extend(sub_tokens)

            for sub_token in sub_tokens:
                if self.use_clip_tokenizer:
                    output_label.append(sub_token)
                else:
                    try:
                        output_label.append(self.tokenizer.vocab[sub_token])
                    except KeyError:
                        output_label.append(self.tokenizer.vocab["[UNK]"])

        return output_tokens, output_label

    def __load_video(self, video_path):
        """Decode and transform the frames sampled for one video.

        Args:
            video_path: local path, or ceph key when ``use_ceph`` is true.

        Returns:
            tuple: ``(video_data, num_frames)`` where ``video_data`` is the
            concatenation (dim 0) of the transformed clips, given a leading
            dimension when the result is 4-D, and ``num_frames`` is the
            number of distinct frame indices requested.

        Raises:
            RuntimeError: if none of the requested frames could be decoded.
        """
        # Under ceph the bytes are fetched once so that the container can be
        # reopened from memory if a second decoding pass is needed.
        raw = self.tcs_loader.client.get(video_path) if self.use_ceph else None

        def open_container():
            if self.use_ceph:
                return av.open(io.BytesIO(raw))
            return av.open(video_path)

        images = dict()
        last = None
        container = open_container()
        try:
            stream = container.streams.video[0]
            total_frames = stream.frames
            fps = float(stream.average_rate)

            if total_frames == 0:
                # The stream does not report a frame count: count by decoding
                # once, then start over with a fresh container.
                total_frames = sum(1 for _ in container.decode(stream))
                exhausted, container = container, None
                exhausted.close()
                container = open_container()
                stream = container.streams.video[0]

            if self.stage == 'train':
                indices = [self._sample_indices(total_frames, fps)]
            else:
                indices = self._get_val_indices(total_frames, fps)

            wanted = set()
            for clip in indices:
                wanted.update(clip)
            num_frames = len(wanted)

            # Frames are numbered by decode order, which matches the decoder's
            # own running index for a sequential decode from the start.
            # NOTE(review): `frame.index` appears to be unavailable in newer
            # PyAV releases - confirm against the PyAV changelog.
            for frame_index, frame in enumerate(container.decode(stream)):
                if frame_index not in wanted:
                    continue
                images[frame_index] = frame.to_rgb().to_image()
                last = frame_index
                if len(images) == num_frames:
                    break
        finally:
            if container is not None:
                container.close()

        if last is None:
            raise RuntimeError('no frames could be decoded from {}'.format(video_path))

        clips = []
        for clip in indices:
            # Indices beyond the decodable range fall back to the last decoded frame.
            frames = [images[i] if i in images else images[last] for i in clip]
            clips.append(self.transform(frames))
        video_data = torch.cat(clips, dim=0)

        if video_data.dim() == 4:
            video_data.unsqueeze_(0)
        return video_data, num_frames

    def __getitem__(self, idx):
        """Return the training sample at position ``idx``.

        Args:
            idx: index into ``self.video_list``.

        Returns:
            dict: ``input_sample`` (video, plus text for ``video_caption``),
            ``target_sample``, ``target_idx``, ``target_set`` and
            ``task_info``, after conversion by ``dict_as_tensor``.

        Raises:
            NotImplementedError: outside the train stage, when the record has
                no caption, or for task types other than ``video_caption``
                and ``video_retrieval``.
            RuntimeError: if no requested frame could be decoded.
        """
        record = self.video_list[idx].as_py()

        # Unsupported combinations are rejected before the (expensive) video decode.
        if self.stage != 'train' or record['caption'] is None:
            raise NotImplementedError
        if self.task_type == 'video_mlm':
            raise NotImplementedError('no needed for masked language modeling when given video now.')
        if self.task_type not in ('video_caption', 'video_retrieval'):
            raise NotImplementedError

        video_id = record['video_id']
        video_path = os.path.join(self.feats_folder, self.idx2name[video_id] + '.mp4')
        video_data, num_frames = self.__load_video(video_path)

        caption = record['caption'] + " <|endoftext|>"
        caption_tokens = self.tokenizer.encode(caption)
        with_labels = self.task_type == 'video_caption'
        if with_labels:
            # One "<|spe|>" placeholder per caption token.
            mlm_labels = self.tokenizer.encode("<|spe|>") * len(caption_tokens)

        if len(caption_tokens) > self.max_seq_len:
            # Truncate but keep the final (end-of-text) token.
            text_len_keep = self.max_seq_len
            caption_tokens = caption_tokens[:(text_len_keep - 1)] + [caption_tokens[-1]]
            if with_labels:
                mlm_labels = mlm_labels[:(text_len_keep - 1)] + [mlm_labels[-1]]

        ret = {
            'input_sample': [{
                'data': video_data, 'invalid_mask': None, 'modality': 'video', 'data_type': 'input',
                'sample_info': {
                    'id': video_id,
                    'path': video_path,
                    'num_views': num_frames,
                    'cat_along_first_dim': True,
                }
            }]
        }

        if self.task_type == 'video_caption':
            source = np.array(caption_tokens, dtype=np.int64)
            source2 = np.array(mlm_labels, dtype=np.int64)
            ret['input_sample'].append({
                'data': [source, source2],
                'invalid_mask': None,
                'modality': 'text',
                'data_type': 'input',
                'sample_info': {
                    'text_spe_cat': True,
                }
            })
            ret.update({
                'target_sample': [],
                'target_idx'   : [np.array(caption_tokens, dtype=np.int64)],
                'target_set'   : copy.deepcopy(self.target_set),
                'task_info'    : copy.deepcopy(self.task_info)
            })
        else:  # 'video_retrieval'
            ret.update({
                'target_sample': [{
                    'data'        : [np.array(caption_tokens, dtype=np.int64)],
                    'modality'    : 'text',
                    'data_type'   : 'target',
                    'invalid_mask': None,
                    'sample_info' : {}
                }],
                'target_idx'   : [],
                'target_set'   : [],
                'task_info'    : copy.deepcopy(self.task_info)
            })

        dict_as_tensor(ret)
        return ret
|
|