from typing import Dict, List
from pathlib import Path
import json
import os

import numpy as np
import torch
from PIL import Image

from llava.datasets.builder import DATASETS
from llava.datasets.base_dataset import FramesTaskDataset
from llava.datasets.data_cfgs import data_configs


class LKVideoDataset(FramesTaskDataset):
    def __init__(self, anno_path=None, data_args=None, fps=1.0, conv_type='multi',
                 select_datasets=None, name='lk_video'):
        self.default_fps = 1.0
        self.fps = fps
        self.conv_type = conv_type
        self.select_datasets = select_datasets
        self.annotation = self.get_dataset(anno_path)
        # TODO: support the 'single' conversation type.
        assert self.conv_type == 'multi', "lk_video conv type must be 'multi'"
        # master_print(f"Finished loading dataset {name} {len(self.annotation)} samples...")
        super().__init__(anno_path=anno_path, data_args=data_args, fps=fps, name=name)

    def __len__(self):
        return len(self.annotation)

    def get_dataset(self, anno_path):
        """Load the annotation JSON and, if `select_datasets` is given, keep only
        samples whose parent directory name matches one of the selected datasets."""
        anno_path = Path(anno_path)
        with anno_path.open('r') as f:
            data = json.load(f)
        if self.select_datasets is not None:
            filtered_data = []
            for sample in data:
                # The dataset name is the parent directory of the sample's frame
                # folder, i.e. <root>/<dataset_name>/<video_id>.
                dataset_name = Path(sample['video']).parent.name
                if dataset_name in self.select_datasets:
                    filtered_data.append(sample)
            data = filtered_data
        return data

    def text_preprocess(self, item) -> List[Dict[str, str]]:
        return item['conversations']

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        item = self.annotation[i]
        ret = {
            'images': self.vis_preprocess(item['video']),
            'conversations': self.text_preprocess(item)
        }
        if 'id' in item:
            ret['id'] = item['id']
        return ret

    @staticmethod
    def _sample_frames(frames, num_segments):
        # Uniformly sample `num_segments` frames across the whole clip; the
        # linspace endpoints keep the first and last frame.
        indices = np.linspace(start=0, stop=len(frames) - 1, num=num_segments).astype(int)
        return [frames[ind] for ind in indices]

    def vis_preprocess(self, vis_path):
        # Frame files are named `<prefix>_<index>.jpeg`; collect and sort by index.
        image_files = []
        for img_path in os.listdir(vis_path):
            if img_path.endswith('.jpeg'):
                img_idx = int(img_path.split('_')[-1][:-len('.jpeg')])
                image_files.append((img_idx, img_path))
        image_files = sorted(image_files, key=lambda img: img[0])
        # TODO: ad-hoc fix, cap at 10 frames.
        if len(image_files) > 10:
            image_files = self._sample_frames(image_files, 10)
        if self.num_segments > 0 and len(image_files) > self.num_segments:
            image_files = self._sample_frames(image_files, self.num_segments)

        images = []
        for image_file in image_files:
            try:
                images.append(Image.open(os.path.join(vis_path, image_file[1])).convert('RGB'))
            except Exception:
                # Skip unreadable frames rather than failing the whole sample.
                continue

        formatted_images = []
        for image in images:
            im = self.preprocess_image(image)
            if isinstance(im, list):
                formatted_images.extend(im)
            else:
                formatted_images.append(im)
        return formatted_images


@DATASETS.register_obj
def lk_video(data_args):
    data_cfg = data_configs['lk_video']
    fps = data_args.external_args['fps']
    conv_type = data_args.external_args['conv_type']
    select_datasets = (data_args.external_args['select_datasets']
                       if 'select_datasets' in data_args.external_args else None)
    return LKVideoDataset(data_cfg['train_data_path'], data_args, fps, conv_type,
                          select_datasets=select_datasets)
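
# For reference, a sketch of the annotation JSON this dataset expects, inferred
# from get_dataset/__getitem__ above: each entry has a 'video' directory of
# `<prefix>_<index>.jpeg` frames (whose parent directory name is what
# `select_datasets` filters on), a 'conversations' list, and an optional 'id'.
# The concrete values below are illustrative placeholders, and the inner
# 'from'/'value' message format is an assumption, not verified by this file:
#
# [
#     {
#         "id": "sample-0000",
#         "video": "/data/frames/<dataset_name>/<video_id>",
#         "conversations": [
#             {"from": "human", "value": "Describe the video."},
#             {"from": "gpt", "value": "A person is ..."}
#         ]
#     }
# ]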

# Commented-out one-off script that drops annotation entries whose frame
# directory no longer exists on disk:
#
# if __name__ == '__main__':
#     import json
#     from tqdm import tqdm
#
#     with open('/mnt/bn/liangkeg/data/xiangchen/finetune_all_detail_vidal200k_videollava_images_vid.json') as f:
#         data = json.load(f)
#     filtered_data = []
#     for item in tqdm(data):
#         image_path = item['video']
#         if os.path.exists(image_path):
#             filtered_data.append(item)
#         else:
#             print(image_path)
#     with open('/mnt/bn/liangkeg/data/xiangchen/finetune_all_detail_vidal200k_videollava_images_vid.json', 'w') as f:
#         json.dump(filtered_data, f)
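
# A quick, self-contained illustration of the uniform frame sampling performed
# in vis_preprocess: with 25 frames capped at 10, _sample_frames picks indices
# spread evenly across the clip (np.linspace endpoints truncated to int, so the
# first and last frames are always kept):
#
# >>> LKVideoDataset._sample_frames(list(range(25)), 10)
# [0, 2, 5, 8, 10, 13, 16, 18, 21, 24]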