import random

from typing import Dict, List

from llava.datasets.builder import DATASETS
from llava.datasets.data_cfgs import data_configs
from llava.datasets.base_dataset import FramesTaskDataset
from llava.datasets.prompts import tt_caption_prompt, tt_caption_prompt2
from llava.constants import DEFAULT_VIDEO_TOKEN


class TTVqaDataset(FramesTaskDataset):
    """Video caption/QA dataset that turns per-video annotations into conversations."""

    def __init__(self, anno_path, data_args=None, fps=2.0, data_cfgs=None, name='tt_vqa'):
        super().__init__(anno_path=anno_path,
                         data_args=data_args,
                         fps=fps,
                         name=name)
        # Fall back to the constructor fps when no dataset config is provided.
        self.default_fps = data_cfgs['fps'] if data_cfgs is not None else fps

    def text_preprocess(self, item) -> List[Dict[str, str]]:
        all_convs = []
        # Resolve the caption prompt list: an optional `caption_prompt` attribute on
        # data_args names one of the imported prompt lists (e.g. tt_caption_prompt2);
        # otherwise use the default tt_caption_prompt.
        if hasattr(self.data_args, 'caption_prompt'):
            cap_prompt = eval(self.data_args.caption_prompt)
        else:
            cap_prompt = tt_caption_prompt

        # Build one human/model round per caption and per QA pair.
        if 'caption' in item:
            all_convs.append([
                {'from': 'human', 'value': random.choice(cap_prompt)},
                {'from': 'model', 'value': item['caption']},
            ])
        if 'qas' in item:
            for qa in item['qas']:
                all_convs.append([
                    {'from': 'human', 'value': qa['q']},
                    {'from': 'model', 'value': qa['a']},
                ])

        # Shuffle the rounds and prepend the video token to the first human turn.
        conversations = []
        random.shuffle(all_convs)
        for idx, conv in enumerate(all_convs):
            if idx == 0:
                conv[0]['value'] = DEFAULT_VIDEO_TOKEN + conv[0]['value']
            conversations.extend(conv)
        return conversations


@DATASETS.register_obj
def tt_vqa(data_args):
    # Prefer an explicitly supplied annotation path; otherwise fall back to the
    # path registered in data_configs.
    if 'train_data_path' in data_args.external_args:
        train_data_path = data_args.external_args['train_data_path']
    else:
        train_data_path = data_configs['tt_vqa']['train_data_path']
    return TTVqaDataset(train_data_path, data_args, 2.0, data_configs['tt_vqa'])
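

# Example usage (a minimal sketch, commented out; the exact shape of `data_args`
# outside this file is an assumption, but it must expose an `external_args` dict
# and may carry a `caption_prompt` attribute naming one of the imported prompt lists):
#
#   dataset = tt_vqa(data_args)                     # normally resolved via the DATASETS registry
#   convs = dataset.text_preprocess({'caption': 'a dog runs',
#                                    'qas': [{'q': 'What runs?', 'a': 'A dog.'}]})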