# model1/llava/datasets/tt_vqa_dataset.py
import random
from typing import Dict, List

from llava.constants import DEFAULT_VIDEO_TOKEN
from llava.datasets.base_dataset import FramesTaskDataset
from llava.datasets.builder import DATASETS
from llava.datasets.data_cfgs import data_configs
# tt_caption_prompt2 is never referenced by name below, but
# data_args.caption_prompt may name it for the eval() lookup
# in text_preprocess, so it must stay importable here.
from llava.datasets.prompts import tt_caption_prompt, tt_caption_prompt2


class TTVqaDataset(FramesTaskDataset):
    def __init__(self, anno_path, data_args=None, fps=2.0, data_cfgs=None, name='tt_vqa'):
        super().__init__(anno_path=anno_path,
                         data_args=data_args,
                         fps=fps,
                         name=name)
        # Fall back to the constructor fps when no dataset config is given;
        # the original indexed data_cfgs['fps'] unconditionally, which raises
        # TypeError when data_cfgs is left at its None default.
        self.default_fps = data_cfgs['fps'] if data_cfgs is not None else fps
    def text_preprocess(self, item) -> List[Dict[str, str]]:
        all_convs = []
        # data_args.caption_prompt, when set, names one of the imported prompt
        # lists (e.g. 'tt_caption_prompt2'); eval() resolves that name in this
        # module's namespace. Otherwise use the default prompt list.
        if hasattr(self.data_args, 'caption_prompt'):
            cap_prompt = eval(self.data_args.caption_prompt)
        else:
            cap_prompt = tt_caption_prompt
        # Turn the caption, if present, into one human/model exchange pairing
        # a randomly chosen caption prompt with the ground-truth caption.
        if 'caption' in item:
            all_convs.append([
                {
                    'from': 'human',
                    'value': random.choice(cap_prompt)
                },
                {
                    'from': 'model',
                    'value': item['caption']
                }
            ])
        # Each QA pair becomes its own human/model exchange.
        if 'qas' in item:
            for qa in item['qas']:
                all_convs.append([
                    {
                        'from': 'human',
                        'value': qa['q']
                    },
                    {
                        'from': 'model',
                        'value': qa['a']
                    }
                ])
        # Shuffle the exchanges, then prepend the video token to the first
        # human turn so the video is referenced exactly once per sample, and
        # flatten the exchanges into a single turn list.
        conversations = []
        random.shuffle(all_convs)
        for idx, conv in enumerate(all_convs):
            if idx == 0:
                conv[0]['value'] = DEFAULT_VIDEO_TOKEN + conv[0]['value']
            conversations.extend(conv)
        return conversations
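
    # A minimal sketch of the annotation record text_preprocess expects.
    # Only the key names are taken from the accesses above ('caption',
    # 'qas', 'q', 'a'); the values and the 'video' key are illustrative
    # assumptions, not part of this file:
    #
    #   {
    #       "video": "clips/0001.mp4",
    #       "caption": "A dog catches a frisbee in a park.",
    #       "qas": [
    #           {"q": "What animal is shown?", "a": "A dog."},
    #           {"q": "Where does the scene take place?", "a": "In a park."}
    #       ]
    #   }
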
@DATASETS.register_obj
def tt_vqa(data_args):
    # Prefer an explicitly supplied annotation path; otherwise fall back to
    # the path registered in the shared data config.
    if 'train_data_path' in data_args.external_args:
        train_data_path = data_args.external_args['train_data_path']
    else:
        train_data_path = data_configs["tt_vqa"]['train_data_path']
    return TTVqaDataset(train_data_path, data_args, 2.0, data_configs["tt_vqa"])
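
# Hedged usage sketch: how the registered builder might be invoked. The
# DATASETS registry API beyond register_obj is not visible in this file, and
# FramesTaskDataset's constructor requirements are unknown, so this stays a
# commented sketch with a hypothetical data_args rather than runnable code:
#
#   from types import SimpleNamespace
#   data_args = SimpleNamespace(
#       external_args={'train_data_path': 'annos/tt_vqa_train.json'})
#   dataset = tt_vqa(data_args)
#   convs = dataset.text_preprocess({
#       'caption': 'A dog catches a frisbee.',
#       'qas': [{'q': 'What animal is shown?', 'a': 'A dog.'}],
#   })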