import json
import random
from pathlib import Path
from typing import Dict, List

from llava.constants import DEFAULT_VIDEO_TOKEN
from llava.datasets.base_dataset import FramesTaskDataset
from llava.datasets.builder import DATASETS
from llava.datasets.data_cfgs import data_configs
from llava.datasets.prompts import tt_caption_prompt, tt_caption_prompt2
from llava.utils import master_print


class GPT4VTTVqaDataset(FramesTaskDataset):
    def __init__(self, anno_path, data_args=None, fps=0.5, conv_type='single', task_types=None, name='gpt4v_tt_vqa'):
        self.default_fps = 0.5
        self.fps = fps
        self.conv_type = conv_type
        self.task_types = task_types
        # Validate the config before building the sample list.
        assert self.conv_type in ('single', 'multi'), "gpt4v_tt_vqa conv_type must be 'single' or 'multi'"
        assert self.task_types is not None, "gpt4v_tt_vqa must have key 'task_types' in yaml config"
        self.annotation = self.get_dataset(anno_path)
        master_print(f"Finished loading dataset {name}, {len(self.annotation)} samples...")
        super().__init__(anno_path=anno_path,
                         data_args=data_args,
                         fps=fps,
                         name=name)
    def get_dataset(self, anno_path):
        """Expand raw annotations into one sample per task type (and, in
        single-turn mode, one sample per QA pair)."""
        dataset = []
        anno_path = Path(anno_path)
        with anno_path.open('r') as f:
            data = json.load(f)
        for info in data:
            for task_type in self.task_types:
                info_task = info.copy()
                # Skip records that lack this task type or have it empty.
                if task_type not in info or len(info_task[task_type]) == 0:
                    continue
                if task_type == 'qas' and self.conv_type == 'single':
                    # Single-turn mode: each QA pair becomes its own sample.
                    for qa_pair in info_task[task_type]:
                        one_info = info_task.copy()
                        one_info[task_type] = [qa_pair]
                        one_info.update({'task_type': task_type})
                        dataset.append(one_info)
                else:
                    # Multi-turn QA (or captioning): keep the record whole.
                    info_task.update({'task_type': task_type})
                    dataset.append(info_task)
        return dataset
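
    # Shape of the annotation items consumed below, as inferred from the
    # fields this class reads ('caption' for captioning samples, 'qas' for
    # question answering); any frame/video fields are handled by
    # FramesTaskDataset and are not shown:
    #
    #   {
    #       "task_type": "caption" | "qas",
    #       "caption": "<video description>",
    #       "qas": [{"q": "<question>", "a": "<answer>"}, ...]
    #   }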
    def text_preprocess(self, item) -> List[Dict[str, str]]:
        all_convs = []
        # An optional yaml override names a prompt list defined in
        # llava.datasets.prompts (e.g. 'tt_caption_prompt2'); eval() resolves
        # that name. Otherwise fall back to the default caption prompts.
        if hasattr(self.data_args, 'caption_prompt'):
            cap_prompt = eval(self.data_args.caption_prompt)
        else:
            cap_prompt = tt_caption_prompt
        if item['task_type'] == 'caption':
            all_convs.append([
                {
                    'from': 'human',
                    'value': random.choice(cap_prompt)
                },
                {
                    'from': 'model',
                    'value': item['caption']
                }
            ])
        else:
            for qa in item['qas']:
                all_convs.append([
                    {
                        'from': 'human',
                        'value': qa['q']
                    },
                    {
                        'from': 'model',
                        'value': qa['a']
                    }
                ])
        # Shuffle the conversation order, then prepend the video token to the
        # first human turn so the frames are injected exactly once.
        conversations = []
        random.shuffle(all_convs)
        for idx, conv in enumerate(all_convs):
            if idx == 0:
                conv[0]['value'] = DEFAULT_VIDEO_TOKEN + conv[0]['value']
            conversations.extend(conv)
        return conversations
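
    # For illustration, a single-turn 'qas' sample yields conversations such
    # as (the leading token is DEFAULT_VIDEO_TOKEN; contents are hypothetical):
    #
    #   [
    #       {'from': 'human', 'value': '<video>What is the person holding?'},
    #       {'from': 'model', 'value': 'A red umbrella.'},
    #   ]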


@DATASETS.register_obj
def gpt4v_tt_vqa(data_args):
    # Prefer an annotation path passed through the yaml config; otherwise
    # fall back to the default registered in data_configs.
    if 'train_data_path' in data_args.external_args:
        anno_path = data_args.external_args['train_data_path']
    else:
        anno_path = data_configs['gpt4v_tt_vqa']['train_data_path']
    fps = data_args.external_args['fps']
    conv_type = data_args.external_args['conv_type']
    task_types = data_args.external_args['task_types']
    return GPT4VTTVqaDataset(anno_path, data_args, fps, conv_type, task_types)
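

# A minimal usage sketch, assuming a data_args object whose external_args dict
# mirrors the yaml config this builder expects; the path below is hypothetical.
#
#   data_args.external_args = {
#       'train_data_path': '/path/to/gpt4v_tt_vqa_train.json',
#       'fps': 0.5,
#       'conv_type': 'single',
#       'task_types': ['caption', 'qas'],
#   }
#   dataset = gpt4v_tt_vqa(data_args)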