import json
import random
from pathlib import Path
from typing import Dict, List

from llava.constants import DEFAULT_VIDEO_TOKEN
from llava.datasets.base_dataset import FramesTaskDataset
from llava.datasets.builder import DATASETS
from llava.datasets.data_cfgs import data_configs
from llava.datasets.prompts import tt_caption_prompt, tt_caption_prompt2


class GPT4VTTVqaDataset(FramesTaskDataset):
    """Video caption / VQA dataset built from GPT-4V-generated annotations."""

    def __init__(self, anno_path, data_args=None, fps=0.5, conv_type='single', task_types=None, name='gpt4v_tt_vqa'):
        self.default_fps = 0.5
        self.fps = fps
        self.conv_type = conv_type
        # Guard against iterating over None in get_dataset(); defaulting to
        # both supported task types is an assumption, not from the original.
        self.task_types = task_types if task_types is not None else ['caption', 'qas']
        assert self.conv_type in ('single', 'multi'), "gpt4v_tt_vqa conv_type must be 'single' or 'multi'"
        self.annotation = self.get_dataset(anno_path)

        super().__init__(anno_path=anno_path,
                         data_args=data_args,
                         fps=fps,
                         name=name)

    def get_dataset(self, anno_path):
        """Flatten the annotation file into one sample per task type; in
        'single' conversation mode, additionally one sample per QA pair."""
        dataset = []
        anno_path = Path(anno_path)
        with anno_path.open('r') as f:
            data = json.load(f)
        for info in data:
            for task_type in self.task_types:
                # Skip entries that lack this task or carry an empty payload.
                if task_type not in info or len(info[task_type]) == 0:
                    continue
                info_task = info.copy()
                if task_type == 'qas' and self.conv_type == 'single':
                    # Single-turn mode: emit one training sample per QA pair.
                    for qa_pair in info_task[task_type]:
                        one_info = info_task.copy()
                        one_info[task_type] = [qa_pair]
                        one_info['task_type'] = task_type
                        dataset.append(one_info)
                else:
                    info_task['task_type'] = task_type
                    dataset.append(info_task)
        return dataset

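    # Example of one entry in the annotation JSON, as implied by get_dataset()
    # above; the field values are illustrative assumptions, not taken from a
    # real annotation file:
    #
    #   {
    #       "caption": "a person dances on a rooftop at sunset",
    #       "qas": [
    #           {"q": "Where is the person dancing?", "a": "On a rooftop."}
    #       ]
    #   }
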
    def text_preprocess(self, item) -> List[Dict[str, str]]:
        """Turn one flattened sample into a list of conversation turns of the
        form {'from': 'human' | 'model', 'value': <text>}."""
        all_convs = []
        if hasattr(self.data_args, 'caption_prompt'):
            # data_args.caption_prompt names one of the prompt lists imported
            # above (e.g. 'tt_caption_prompt2'); eval() resolves the name.
            cap_prompt = eval(self.data_args.caption_prompt)
        else:
            cap_prompt = tt_caption_prompt
        if item['task_type'] == 'caption':
            all_convs.append([
                {
                    'from': 'human',
                    'value': random.choice(cap_prompt)
                },
                {
                    'from': 'model',
                    'value': item['caption']
                }
            ])
        else:
            for qa in item['qas']:
                all_convs.append([
                    {
                        'from': 'human',
                        'value': qa['q']
                    },
                    {
                        'from': 'model',
                        'value': qa['a']
                    }
                ])

        # Shuffle the per-pair conversations, then prepend the video token to
        # the very first human turn only.
        conversations = []
        random.shuffle(all_convs)
        for idx, conv in enumerate(all_convs):
            if idx == 0:
                conv[0]['value'] = DEFAULT_VIDEO_TOKEN + conv[0]['value']
            conversations.extend(conv)
        return conversations


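# Shape of text_preprocess() output for a single-QA sample (strings are
# illustrative only); the first human turn carries the video token:
#
#   [{'from': 'human', 'value': DEFAULT_VIDEO_TOKEN + 'Where is the person dancing?'},
#    {'from': 'model', 'value': 'On a rooftop.'}]
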
@DATASETS.register_obj
def gpt4v_tt_vqa(data_args):
    # Prefer an explicitly supplied annotation path; otherwise fall back to
    # the default registered in data_configs.
    if 'train_data_path' in data_args.external_args:
        anno_path = data_args.external_args['train_data_path']
    else:
        anno_path = data_configs['gpt4v_tt_vqa']['train_data_path']
    fps = data_args.external_args['fps']
    conv_type = data_args.external_args['conv_type']
    task_types = data_args.external_args['task_types']
    return GPT4VTTVqaDataset(anno_path, data_args, fps, conv_type, task_types)
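
# Minimal usage sketch (hypothetical path and args; the real data_args object
# comes from the surrounding training configuration):
#
#   dataset = GPT4VTTVqaDataset(
#       anno_path='annotations/gpt4v_tt_vqa.json',  # hypothetical path
#       data_args=data_args,
#       fps=0.5,
#       conv_type='single',
#       task_types=['caption', 'qas'],
#   )
#   convs = dataset.text_preprocess(dataset.annotation[0])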