from typing import Dict, Optional, Sequence, List
from pathlib import Path
import json
import os
import random

import numpy as np
import torch
from PIL import Image

from llava.constants import DEFAULT_VIDEO_TOKEN
from llava.datasets.base_dataset import FramesTaskDataset
from llava.datasets.builder import DATASETS
from llava.datasets.data_cfgs import data_configs
# internvid_prompt may be selected at runtime via data_args.caption_prompt (eval'd below).
from llava.datasets.prompts import tt_caption_prompt, internvid_prompt


class PromptV1Dataset(FramesTaskDataset):
    def __init__(self, anno_path=None, data_args=None, name='promptv1_2_internal', task_types=None):
        self.default_fps = 1.0
        self.task_types = task_types
        # NOTE: get_dataset() reads self.conv_type, which must already be available
        # here (e.g. as a class attribute of FramesTaskDataset), since super().__init__()
        # runs afterwards.
        self.annotation = self.get_dataset(anno_path)
        super().__init__(anno_path=anno_path,
                         data_args=data_args,
                         name=name)

    def __len__(self):
        return len(self.annotation)

    def get_dataset(self, anno_path):
        """Load the annotation file and expand each record into one sample per task type."""
        dataset = []
        anno_path = Path(anno_path)
        with anno_path.open('rb') as f:
            data = json.load(f)
        for info in data:
            for task_type in self.task_types:
                info_task = info.copy()
                if task_type not in info or len(info_task[task_type]) == 0:
                    continue
                if task_type == 'qas' and self.conv_type == 'single':
                    # Single-turn mode: split each QA pair into its own sample.
                    for qa_pair in info_task[task_type]:
                        one_info = info_task.copy()
                        one_info[task_type] = [qa_pair]
                        one_info.update({'task_type': task_type})
                        dataset.append(one_info)
                else:
                    info_task.update({'task_type': task_type})
                    dataset.append(info_task)
        return dataset

    def text_preprocess(self, item) -> List[Dict[str, str]]:
        all_convs = []
        if hasattr(self.data_args, 'caption_prompt'):
            cap_prompt = eval(self.data_args.caption_prompt)
        else:
            cap_prompt = tt_caption_prompt

        if item['task_type'] == 'refine_caption':
            all_convs.append([
                {
                    'from': 'human',
                    'value': random.choice(cap_prompt)
                },
                {
                    'from': 'model',
                    'value': item['refine_caption']
                }
            ])
        else:
            for qa in item['qas']:
                all_convs.append([
                    {
                        'from': 'human',
                        'value': qa['q']
                    },
                    {
                        'from': 'model',
                        'value': qa['a']
                    }
                ])

        conversations = []
        random.shuffle(all_convs)
        for idx, conv in enumerate(all_convs):
            if idx == 0:
                # Prepend the video token to the first human turn only.
                conv[0]['value'] = DEFAULT_VIDEO_TOKEN + conv[0]['value']
            conversations.extend(conv)
        return conversations

    # def __getitem__(self, i) -> Dict[str, torch.Tensor]:
    #     item = self.annotation[i]
    #     ret = {
    #         'images': self.vis_preprocess(item['video_path']),
    #         'conversations': self.text_preprocess(item)
    #     }
    #     if 'id' in item:
    #         ret['id'] = item['id']
    #     return ret

    # @staticmethod
    # def _sample_frames(frames, num_segments):
    #     indices = np.linspace(start=0, stop=len(frames) - 1, num=num_segments).astype(int)
    #     frames = [frames[ind] for ind in indices]
    #     return frames

    # def vis_preprocess(self, vis_path):
    #     image_files = []
    #     for img_path in os.listdir(vis_path):
    #         if img_path.endswith('.jpeg'):
    #             img_idx = int(img_path.split('_')[-1][:-5])
    #             image_files.append((img_idx, img_path))
    #     image_files = sorted(image_files, key=lambda img: img[0])
    #     # TODO: ad hoc fix, only 10 frames
    #     if len(image_files) > 10:
    #         image_files = self._sample_frames(image_files, 10)
    #     if self.num_segments > 0 and len(image_files) > self.num_segments:
    #         image_files = self._sample_frames(image_files, self.num_segments)
    #     images = []
    #     for image_file in image_files:
    #         try:
    #             images.append(Image.open(os.path.join(vis_path, image_file[1])).convert('RGB'))
    #         except Exception:
    #             continue
    #     formatted_images = []
    #     for image in images:
    #         im = self.preprocess_image(image)
    #         if isinstance(im, list):
    #             formatted_images.extend(im)
    #         else:
    #             formatted_images.append(im)
    #     return formatted_images


@DATASETS.register_obj
def promptv1_2_internal(data_args):
    data_cfg = data_configs['promptv1_2_internal']
    task_types = data_args.external_args['task_types']
    return PromptV1Dataset(anno_path=data_cfg['train_data_path'],
                           data_args=data_args,
                           task_types=task_types)
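
# ---------------------------------------------------------------------------
# Minimal usage sketch, not part of the registered pipeline. It assumes that
# data_configs['promptv1_2_internal']['train_data_path'] points to a JSON list
# of records shaped like {"refine_caption": "...", "qas": [{"q": ..., "a": ...}]},
# and that FramesTaskDataset.__init__ tolerates a bare namespace for data_args
# (a real run passes the full training data_args). 'qas' is left out of
# task_types here so the sketch does not depend on self.conv_type.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from types import SimpleNamespace

    demo_args = SimpleNamespace(external_args={'task_types': ['refine_caption']})
    dataset = promptv1_2_internal(demo_args)
    print(f'loaded {len(dataset)} samples')
    # text_preprocess() turns one expanded record into a conversation list and
    # prepends DEFAULT_VIDEO_TOKEN to the first human turn.
    print(dataset.text_preprocess(dataset.annotation[0]))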