# model1/llava/datasets/lk_video_dataset.py
import json
import os
from pathlib import Path
from typing import Dict, List

import numpy as np
import torch
from PIL import Image

from llava.datasets.base_dataset import FramesTaskDataset
from llava.datasets.builder import DATASETS
from llava.datasets.data_cfgs import data_configs


class LKVideoDataset(FramesTaskDataset):
    def __init__(self, anno_path=None, data_args=None, fps=1.0, conv_type='multi', select_datasets=None, name='lk_video'):
        self.default_fps = 1.0
        self.fps = fps
        self.conv_type = conv_type
        self.select_datasets = select_datasets
        self.annotation = self.get_dataset(anno_path)
        # TODO: support the 'single' conversation type
        assert self.conv_type == 'multi', "lk_video conv_type must be 'multi'"
        super().__init__(anno_path=anno_path,
                         data_args=data_args,
                         fps=fps,
                         name=name)

    def __len__(self):
        return len(self.annotation)

    def get_dataset(self, anno_path):
        anno_path = Path(anno_path)
        with anno_path.open('r') as f:
            data = json.load(f)
        if self.select_datasets is not None:
            # Keep only samples whose parent directory name matches one of the
            # selected datasets, e.g. '<root>/<dataset_name>/<video_id>'.
            filtered_data = []
            for sample in data:
                video_path = Path(sample['video'])
                dataset_name = video_path.parent.name
                if dataset_name in self.select_datasets:
                    filtered_data.append(sample)
            data = filtered_data
        return data
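
    # Illustrative annotation record, inferred from the fields this class
    # reads (the exact schema is an assumption; all values are hypothetical):
    #   {
    #       "id": "sample_0001",
    #       "video": "<frames_root>/<dataset_name>/<video_id>",
    #       "conversations": [
    #           {"from": "human", "value": "<video>\nDescribe the clip."},
    #           {"from": "gpt", "value": "A person is ..."}
    #       ]
    #   }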

    def text_preprocess(self, item) -> List[Dict[str, str]]:
        return item['conversations']

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        item = self.annotation[i]
        ret = {
            'images': self.vis_preprocess(item['video']),
            'conversations': self.text_preprocess(item)
        }
        if 'id' in item:
            ret['id'] = item['id']
        return ret

    @staticmethod
    def _sample_frames(frames, num_segments):
        # Select num_segments frames at evenly spaced indices across the clip.
        indices = np.linspace(start=0, stop=len(frames) - 1, num=num_segments).astype(int)
        frames = [frames[ind] for ind in indices]
        return frames
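
    # For illustration: with 20 frames and num_segments=5, np.linspace gives
    # [0.0, 4.75, 9.5, 14.25, 19.0], which truncates to indices [0, 4, 9, 14, 19].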

    def vis_preprocess(self, vis_path):
        image_files = []
        for img_path in os.listdir(vis_path):
            if img_path.endswith('.jpeg'):
                # Frame files end in '_<index>.jpeg'; drop the 5-character
                # '.jpeg' suffix and parse the trailing frame index.
                img_idx = int(img_path.split('_')[-1][:-5])
                image_files.append((img_idx, img_path))
        image_files = sorted(image_files, key=lambda img: img[0])
        # TODO: ad-hoc fix, cap at 10 frames
        if len(image_files) > 10:
            image_files = self._sample_frames(image_files, 10)
        if self.num_segments > 0 and len(image_files) > self.num_segments:
            image_files = self._sample_frames(image_files, self.num_segments)
        images = []
        for image_file in image_files:
            try:
                images.append(Image.open(os.path.join(vis_path, image_file[1])).convert('RGB'))
            except Exception:
                # Skip frames that fail to open or decode.
                continue
        formatted_images = []
        for image in images:
            im = self.preprocess_image(image)
            if isinstance(im, list):
                formatted_images.extend(im)
            else:
                formatted_images.append(im)
        return formatted_images
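
    # Note: num_segments and preprocess_image are presumably provided by the
    # FramesTaskDataset base class (neither is defined in this file). The
    # isinstance check above handles preprocess_image returning either a single
    # preprocessed image or a list (e.g. if one frame is split into tiles).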


@DATASETS.register_obj
def lk_video(data_args):
    data_cfg = data_configs['lk_video']
    fps, conv_type = data_args.external_args['fps'], data_args.external_args['conv_type']
    select_datasets = data_args.external_args.get('select_datasets', None)
    return LKVideoDataset(data_cfg['train_data_path'], data_args, fps, conv_type, select_datasets=select_datasets)
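
# Minimal usage sketch (hypothetical values; in practice fps, conv_type and
# select_datasets come from the dataset YAML via data_args.external_args, and
# 'train_data_path' from data_configs['lk_video']):
#
#   data_args.external_args = {'fps': 1.0, 'conv_type': 'multi',
#                              'select_datasets': ['webvid']}
#   dataset = lk_video(data_args)
#   sample = dataset[0]   # {'images': [...], 'conversations': [...], 'id': ...}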

# One-off cleanup script that dropped annotation entries whose frame
# directories no longer existed:
# if __name__ == '__main__':
#     import json
#     from tqdm import tqdm
#     with open('/mnt/bn/liangkeg/data/xiangchen/finetune_all_detail_vidal200k_videollava_images_vid.json') as f:
#         data = json.load(f)
#     filtered_data = []
#     for item in tqdm(data):
#         image_path = item['video']
#         if os.path.exists(image_path):
#             filtered_data.append(item)
#         else:
#             print(image_path)
#     with open('/mnt/bn/liangkeg/data/xiangchen/finetune_all_detail_vidal200k_videollava_images_vid.json', 'w') as f:
#         json.dump(filtered_data, f)