import os
import json
import random
from pathlib import Path
from typing import Dict, List

import numpy as np
import torch
from PIL import Image

from llava.constants import DEFAULT_VIDEO_TOKEN
from llava.datasets.base_dataset import FramesTaskDataset
from llava.datasets.builder import DATASETS
from llava.datasets.data_cfgs import data_configs
from llava.datasets.prompts import tt_caption_prompt, internvid_prompt
from llava.utils import master_print


class GPT4VInternalDataset(FramesTaskDataset):
    def __init__(self, anno_path=None, data_args=None, fps=0.5, conv_type='single',
                 task_types=None, name='gpt4v_internal'):
        self.default_fps = 2.0
        self.fps = fps
        self.conv_type = conv_type
        self.task_types = task_types
        self.annotation = self.get_dataset(anno_path)
        assert self.conv_type in ('single', 'multi'), "gpt4v_internal conv_type must be 'single' or 'multi'"
        # assert hasattr(self.data_args, 'task_types'), "gpt4v_internal must have key 'task_types' in yaml config"
        # master_print(f"Finished loading dataset {name} {len(self.annotation)} samples...")
        super().__init__(anno_path=anno_path,
                         data_args=data_args,
                         fps=fps,
                         name=name)

    def __len__(self):
        return len(self.annotation)

    def get_dataset(self, anno_path):
        """Load the annotation file and expand it into one sample per task type."""
        dataset = []
        anno_path = Path(anno_path)
        with anno_path.open('rb') as f:
            data = json.load(f)
        for info in data:
            # Drop QA pairs with an empty question or answer.
            filtered_qa = []
            for qa in info['qa_pairs']:
                if len(qa['question']) == 0 or len(qa['answer']) == 0:
                    continue
                filtered_qa.append(qa)
            info['qa_pairs'] = filtered_qa
            for task_type in self.task_types:
                info_task = info.copy()
                if len(info_task[task_type]) == 0:
                    continue
                if task_type == 'qa_pairs' and self.conv_type == 'single':
                    # With conv_type == 'single', each QA pair becomes its own sample.
                    for qa_pair in info_task[task_type]:
                        one_info = info_task.copy()
                        one_info[task_type] = [qa_pair]
                        one_info.update({'task_type': task_type})
                        dataset.append(one_info)
                else:
                    info_task.update({'task_type': task_type})
                    dataset.append(info_task)
        return dataset

    @staticmethod
    def _sample_frames(frames, num_segments):
        # Uniformly sample `num_segments` frames across the clip.
        indices = np.linspace(start=0, stop=len(frames) - 1, num=num_segments).astype(int)
        frames = [frames[ind] for ind in indices]
        return frames

    def text_preprocess(self, item) -> List[Dict[str, str]]:
        all_convs = []
        # TODO: different prompt for summary and detail
        if item['task_type'] == 'summary':
            all_convs.append([
                {
                    'from': 'human',
                    'value': random.choice(internvid_prompt)
                },
                {
                    'from': 'model',
                    'value': item['summary']
                }
            ])
        elif item['task_type'] == 'detail':
            all_convs.append([
                {
                    'from': 'human',
                    'value': random.choice(tt_caption_prompt)
                },
                {
                    'from': 'model',
                    'value': item['detail']
                }
            ])
        else:
            for qa in item['qa_pairs']:
                all_convs.append([
                    {
                        'from': 'human',
                        'value': qa['question']
                    },
                    {
                        'from': 'model',
                        'value': qa['answer']
                    }
                ])

        conversations = []
        random.shuffle(all_convs)
        for idx, conv in enumerate(all_convs):
            if idx == 0:
                # Prepend the video token to the first human turn only.
                conv[0]['value'] = DEFAULT_VIDEO_TOKEN + conv[0]['value']
            conversations.extend(conv)
        return conversations

    def vis_preprocess(self, vis_path):
        # Frame files are named by frame index; skip auxiliary 'cuttime' files.
        image_files = [(os.path.splitext(img)[0], img) for img in os.listdir(vis_path)
                       if not img.startswith('cuttime')]
        image_files = [(int(x[0]), x[1]) for x in image_files]
        image_files = sorted(image_files, key=lambda img: img[0])
        # Keep at most 10 uniformly spaced frames before the num_segments check.
        intervals = np.linspace(start=0, stop=len(image_files) - 1, num=10).astype(int)
        image_files = [image_files[i] for i in intervals]
        if self.num_segments > 0 and len(image_files) > self.num_segments:
            image_files = self._sample_frames(image_files, self.num_segments)

        images = []
        for image_file in image_files:
            try:
                images.append(Image.open(os.path.join(vis_path, image_file[1])).convert('RGB'))
            except Exception:
                continue

        formatted_images = []
        for image in images:
            im = self.preprocess_image(image)
            if isinstance(im, list):
                formatted_images.extend(im)
            else:
                formatted_images.append(im)
        # images = [self.preprocess_image(image) for image in images]
        return formatted_images

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        item = self.annotation[i]
        ret = {
            'images': self.vis_preprocess(item['vis_path']),
            'conversations': self.text_preprocess(item)
        }
        if 'id' in item:
            ret['id'] = item['id']
        return ret


@DATASETS.register_obj
def gpt4v_internal(data_args):
    data_cfg = data_configs['gpt4v_internal']
    if 'train_data_path' in data_args.external_args:
        train_data_path = data_args.external_args['train_data_path']
    else:
        train_data_path = data_cfg['train_data_path']
    fps = data_args.external_args['fps']
    conv_type = data_args.external_args['conv_type']
    task_types = data_args.external_args['task_types']
    return GPT4VInternalDataset(train_data_path, data_args, fps, conv_type, task_types)
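

# ---------------------------------------------------------------------------
# Illustrative config sketch (an assumption, values hypothetical): the builder
# above pulls `train_data_path`, `fps`, `conv_type` and `task_types` out of
# `data_args.external_args`, which the commented-out assert suggests comes
# from the dataset's yaml config. An entry for this dataset might therefore
# look roughly like:
#
#   gpt4v_internal:
#     train_data_path: /path/to/gpt4v_internal_annotations.json
#     fps: 0.5
#     conv_type: single            # 'single' or 'multi'
#     task_types: [summary, detail, qa_pairs]
# ---------------------------------------------------------------------------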