# model1/llava/datasets/gpt4v_tt_vqa_dataset.py
import os
import json
import random
from pathlib import Path
from llava.datasets.builder import DATASETS
from typing import Dict, Optional, Sequence, List
from llava.datasets.data_cfgs import data_configs
from llava.datasets.base_dataset import FramesTaskDataset
from llava.datasets.prompts import tt_caption_prompt, tt_caption_prompt2
from llava.constants import DEFAULT_VIDEO_TOKEN
from llava.utils import master_print
class GPT4VTTVqaDataset(FramesTaskDataset):
def __init__(self, anno_path, data_args=None, fps=0.5, conv_type='single', task_types=None, name='gpt4v_tt_vqa'):
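        """Frame-based video caption/VQA dataset built from GPT-4V annotations.

        Args:
            anno_path: path to the JSON annotation file.
            data_args: dataset/training arguments forwarded to FramesTaskDataset.
            fps: frame sampling rate forwarded to FramesTaskDataset.
            conv_type: 'single' (each QA pair becomes its own sample) or
                'multi' (all QA pairs of a video stay in one sample).
            task_types: annotation keys to use, e.g. ['caption', 'qas'].
        """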
self.default_fps = 0.5
self.fps = fps
self.conv_type = conv_type
self.task_types = task_types
self.annotation = self.get_dataset(anno_path)
        assert self.conv_type in ('single', 'multi'), "gpt4v_tt_vqa conv_type must be 'single' or 'multi'"
# assert hasattr(self.data_args, 'task_types'), "gpt4v_tt_vqa must have key 'task_types' in yaml config"
# master_print(f"Finished loading dataset {name} {len(self.annotation)} samples...")
super().__init__(anno_path=anno_path,
data_args=data_args,
fps=fps,
name=name)
def get_dataset(self, anno_path):
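        """Flatten the raw JSON annotations into per-sample records.

        Each record keeps the original video info plus a 'task_type' key.
        For 'qas' annotations with conv_type='single', every QA pair is
        emitted as its own sample; otherwise the whole task entry
        (caption or full QA list) becomes a single sample.
        """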
dataset = []
anno_path = Path(anno_path)
with anno_path.open('rb') as f:
data = json.load(f)
for info in data:
for task_type in self.task_types:
info_task = info.copy()
                if task_type not in info_task or len(info_task[task_type]) == 0:
continue
if task_type == 'qas' and self.conv_type == 'single':
for qa_pair in info_task[task_type]:
one_info = info_task.copy()
one_info[task_type] = [qa_pair]
one_info.update({
'task_type': task_type
})
dataset.append(one_info)
else:
info_task.update({
'task_type': task_type
})
dataset.append(info_task)
return dataset
def text_preprocess(self, item) -> List[Dict[str, str]]:
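        """Build the conversation turns for one sample.

        Caption samples yield a single prompt/answer pair using a randomly
        chosen caption prompt; QA samples yield one pair per question.
        The conversation groups are shuffled and DEFAULT_VIDEO_TOKEN is
        prepended to the first human turn so the video is referenced once.
        """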
all_convs = []
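        # When the config sets caption_prompt, it is expected to be a Python
        # expression (typically the name of one of the imported prompt lists,
        # e.g. 'tt_caption_prompt2') that eval() resolves here; otherwise the
        # default tt_caption_prompt list is used.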
if hasattr(self.data_args, 'caption_prompt'):
cap_prompt = eval(self.data_args.caption_prompt)
else:
cap_prompt = tt_caption_prompt
if item['task_type'] == 'caption':
all_convs.append([
{
'from': 'human',
'value': random.choice(cap_prompt)
},
{
'from': 'model',
'value': item['caption']
}
])
else:
            for qa in item['qas']:
all_convs.append([
{
'from': 'human',
'value': qa['q']
},
{
'from': 'model',
'value': qa['a']
}
])
conversations = []
random.shuffle(all_convs)
for idx, conv in enumerate(all_convs):
if idx == 0:
conv[0]['value'] = DEFAULT_VIDEO_TOKEN + conv[0]['value']
conversations.extend(conv)
return conversations
@DATASETS.register_obj
def gpt4v_tt_vqa(data_args):
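    """Registry entry point: build a GPT4VTTVqaDataset from data_args.

    The annotation path comes from external_args['train_data_path'] when
    provided, otherwise from the default gpt4v_tt_vqa entry in data_configs.
    """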
    if 'train_data_path' in data_args.external_args:
        anno_path = data_args.external_args['train_data_path']
    else:
        anno_path = data_configs["gpt4v_tt_vqa"]['train_data_path']
    fps = data_args.external_args['fps']
    conv_type = data_args.external_args['conv_type']
    task_types = data_args.external_args['task_types']
return GPT4VTTVqaDataset(anno_path, data_args, fps, conv_type, task_types)
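

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original file. It assumes a
    # namespace-style data_args whose external_args dict carries the same keys
    # gpt4v_tt_vqa() reads above, that @DATASETS.register_obj returns the
    # decorated function unchanged, and that FramesTaskDataset needs no extra
    # fields at construction time; the annotation path below is hypothetical.
    from types import SimpleNamespace

    demo_args = SimpleNamespace(external_args={
        'train_data_path': 'annotations/gpt4v_tt_vqa.json',  # hypothetical path
        'fps': 0.5,
        'conv_type': 'single',
        'task_types': ['caption', 'qas'],
    })
    demo_dataset = gpt4v_tt_vqa(demo_args)
    print(f"loaded {len(demo_dataset.annotation)} samples")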