import json
import os
import random
from pathlib import Path
from typing import Dict, List

import numpy as np
import torch
from PIL import Image

from llava.constants import DEFAULT_VIDEO_TOKEN
from llava.datasets.base_dataset import FramesTaskDataset
from llava.datasets.builder import DATASETS
from llava.datasets.data_cfgs import data_configs
from llava.datasets.prompts import tt_caption_prompt, internvid_prompt


class GPT4VPublicDataset(FramesTaskDataset):
    def __init__(self,
                 anno_path=None,
                 data_args=None,
                 fps=1.0,
                 conv_type='single',
                 task_types=None,
                 sample_method='uniform',
                 name='gpt4v_public'):
        self.default_fps = 1.0
        self.fps = fps
        self.conv_type = conv_type
        self.task_types = task_types
        self.annotation = self.get_dataset(anno_path)
        self.sample_method = sample_method
        assert self.conv_type in ('single', 'multi'), \
            "gpt4v_public conv_type must be one of 'single'/'multi'"
        assert self.sample_method in ('sequential', 'uniform'), \
            "gpt4v_public sample_method must be one of 'sequential'/'uniform'"
        super().__init__(anno_path=anno_path,
                         data_args=data_args,
                         fps=fps,
                         name=name)

    def __len__(self):
        return len(self.annotation)

    def get_dataset(self, anno_path):
        dataset = []
        anno_path = Path(anno_path)
        with anno_path.open('r') as f:
            data = json.load(f)
        for info in data:
            filtered_qa = []
            if 'qa_pairs' not in info:
                # Fall back to the raw 'conversation' list: pair each question line
                # (contains 'C') with the answer line (contains 'A') that follows it.
                index = 0
                while index < len(info['conversation']):
                    if len(info['conversation'][index].strip()) == 0:
                        index += 1
                        continue
                    if 'C' in info['conversation'][index]:
                        if index + 1 < len(info['conversation']) and 'A' in info['conversation'][index + 1]:
                            filtered_qa.append([info['conversation'][index],
                                                info['conversation'][index + 1]])
                            index += 2
                        else:
                            index += 1
                    else:
                        index += 1
            else:
                # Keep only QA pairs where both question and answer are non-empty.
                for qa in info['qa_pairs']:
                    if len(qa[0]) == 0 or len(qa[1]) == 0:
                        continue
                    filtered_qa.append(qa)
            info['qa_pairs'] = filtered_qa
            for task_type in self.task_types:
                info_task = info.copy()
                if len(info_task[task_type]) == 0:
                    continue
                if task_type == 'qa_pairs' and self.conv_type == 'single':
                    # In single-turn mode, emit one sample per QA pair.
                    for qa_pair in info_task[task_type]:
                        one_info = info_task.copy()
                        one_info[task_type] = [qa_pair]
                        one_info.update({'task_type': task_type})
                        dataset.append(one_info)
                else:
                    info_task.update({'task_type': task_type})
                    dataset.append(info_task)
        return dataset
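    # Illustrative sketch (not from the original file): based on the parsing in
    # get_dataset() above, each record in the annotation JSON appears to look
    # roughly like the dict below. Only the key names come from the code; the
    # values are hypothetical placeholders.
    #
    # {
    #     "vis_path": "<directory of *_<idx>.jpeg frame files>",
    #     "summary": ["a short video-level summary"],          # list or plain string
    #     "detail": ["a detailed caption"],                     # list or plain string
    #     "qa_pairs": [["C: question text", "A: answer text"]],
    #     "conversation": ["C: question", "A: answer"],         # fallback when 'qa_pairs' is absent
    #     "id": "optional sample id"
    # }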
    def text_preprocess(self, item) -> List[Dict[str, str]]:
        all_convs = []
        # TODO: different prompt for summary and detail
        if item['task_type'] == 'summary':
            summary = ''
            if isinstance(item['summary'], list):
                # Take the first non-empty summary string.
                for s in item['summary']:
                    if len(s.strip()) != 0:
                        summary = s
                        break
            else:
                summary = item['summary']
            all_convs.append([
                {'from': 'human', 'value': random.choice(internvid_prompt)},
                {'from': 'model', 'value': summary},
            ])
        elif item['task_type'] == 'detail':
            detail = ''
            if isinstance(item['detail'], list):
                # Take the first non-empty detailed caption.
                for s in item['detail']:
                    if len(s.strip()) != 0:
                        detail = s
                        break
            else:
                detail = item['detail']
            all_convs.append([
                {'from': 'human', 'value': random.choice(tt_caption_prompt)},
                {'from': 'model', 'value': detail},
            ])
        else:
            for qa in item['qa_pairs']:
                all_convs.append([
                    {'from': 'human', 'value': qa[0]},
                    {'from': 'model', 'value': qa[1]},
                ])
        conversations = []
        random.shuffle(all_convs)
        for idx, conv in enumerate(all_convs):
            if idx == 0:
                # Prepend the video token to the first human turn only.
                conv[0]['value'] = DEFAULT_VIDEO_TOKEN + conv[0]['value']
            conversations.extend(conv)
        return conversations

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        item = self.annotation[i]
        ret = {
            'images': self.vis_preprocess(item['vis_path']),
            'conversations': self.text_preprocess(item),
        }
        if 'id' in item:
            ret['id'] = item['id']
        return ret

    def _sample_frames(self, frames, num_segments, preprocess=False):
        if preprocess:
            if self.sample_method == 'uniform':
                indices = np.linspace(start=0, stop=len(frames) - 1, num=num_segments).astype(int)
            elif self.sample_method == 'sequential':
                # Sequential sampling takes the first num_segments frames.
                indices = range(num_segments)
            else:
                raise NotImplementedError
        else:
            indices = np.linspace(start=0, stop=len(frames) - 1, num=num_segments).astype(int)
        frames = [frames[ind] for ind in indices]
        return frames

    def vis_preprocess(self, vis_path):
        image_files = []
        for img_path in os.listdir(vis_path):
            if img_path.endswith('.jpeg'):
                # Frame files are named '<prefix>_<index>.jpeg'; sort by the index.
                img_idx = int(img_path.split('_')[-1][:-5])
                image_files.append((img_idx, img_path))
        image_files = sorted(image_files, key=lambda img: img[0])
        # TODO: ad hoc fix, keep only 10 frames
        if len(image_files) > 10:
            image_files = self._sample_frames(image_files, 10, preprocess=True)
        if self.num_segments > 0 and len(image_files) > self.num_segments:
            image_files = self._sample_frames(image_files, self.num_segments)
        images = []
        for image_file in image_files:
            try:
                images.append(Image.open(os.path.join(vis_path, image_file[1])).convert('RGB'))
            except Exception:
                continue
        formatted_images = []
        for image in images:
            im = self.preprocess_image(image)
            if isinstance(im, list):
                formatted_images.extend(im)
            else:
                formatted_images.append(im)
        return formatted_images


@DATASETS.register_obj
def gpt4v_public(data_args):
    data_cfg = data_configs['gpt4v_public']
    if 'train_data_path' in data_args.external_args:
        data_cfg['train_data_path'] = data_args.external_args['train_data_path']
    anno_path = data_cfg['train_data_path']
    fps = data_args.external_args['fps']
    conv_type = data_args.external_args['conv_type']
    task_types = data_args.external_args['task_types']
    if 'sample_method' in data_args.external_args:
        sample_method = data_args.external_args['sample_method']
    else:
        sample_method = 'uniform'
    return GPT4VPublicDataset(anno_path, data_args, fps, conv_type, task_types, sample_method)
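
# Minimal self-contained sketch (not part of the original training pipeline): it only
# prints the index pattern produced by the uniform sampling used in _sample_frames above.
# The frame/segment counts are hypothetical values chosen for illustration.
if __name__ == '__main__':
    num_frames, num_segments = 25, 10
    indices = np.linspace(start=0, stop=num_frames - 1, num=num_segments).astype(int)
    print(list(indices))  # evenly spaced frame indices: [0, 2, 5, 8, 10, 13, 16, 18, 21, 24]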