import os
import random
import json
from pathlib import Path
from llava.datasets.builder import DATASETS
from typing import Dict, Optional, Sequence, List
from llava.datasets.data_cfgs import data_configs
from llava.datasets.base_dataset import FramesTaskDataset
from llava.datasets.prompts import tt_caption_prompt, tt_caption_prompt2
from llava.constants import DEFAULT_VIDEO_TOKEN


class TTVqaDataset(FramesTaskDataset):
    """Video caption / QA dataset built on pre-extracted frames."""

    def __init__(self, anno_path, data_args=None, fps=2.0, data_cfgs=None, name='tt_vqa'):
        super().__init__(anno_path=anno_path,
                         data_args=data_args,
                         fps=fps,
                         name=name)
        # data_cfgs is expected to provide the sampling fps; fall back to the
        # ctor argument so a missing config does not crash with a TypeError.
        self.default_fps = data_cfgs['fps'] if data_cfgs is not None else fps
    def text_preprocess(self, item) -> List[Dict[str, str]]:
        """Turn one annotation item into a flat, shuffled conversation list.

        Each conversation is a [human, model] turn pair; the video token is
        prepended to the first human turn after shuffling.
        """
        all_convs = []

        # Pick the caption prompt pool: an explicitly configured one (resolved
        # by name via eval, e.g. 'tt_caption_prompt2') or the default pool.
        if hasattr(self.data_args, 'caption_prompt'):
            cap_prompt = eval(self.data_args.caption_prompt)
        else:
            cap_prompt = tt_caption_prompt

        # Captioning turn: a randomly chosen prompt paired with the caption.
        if 'caption' in item:
            all_convs.append([
                {
                    'from': 'human',
                    'value': random.choice(cap_prompt)
                },
                {
                    'from': 'model',
                    'value': item['caption']
                }
            ])

        # QA turns: one [question, answer] pair per annotated QA.
        if 'qas' in item:
            for qa in item['qas']:
                all_convs.append([
                    {
                        'from': 'human',
                        'value': qa['q']
                    },
                    {
                        'from': 'model',
                        'value': qa['a']
                    }
                ])

        # Shuffle the turn pairs, then mark the first human turn with the
        # video token so the model sees the visual input exactly once.
        conversations = []
        random.shuffle(all_convs)
        for idx, conv in enumerate(all_convs):
            if idx == 0:
                conv[0]['value'] = DEFAULT_VIDEO_TOKEN + conv[0]['value']
            conversations.extend(conv)
        return conversations


@DATASETS.register_obj
def tt_vqa(data_args):
    # Prefer an explicitly passed annotation path; otherwise fall back to the
    # path registered in the global data_configs.
    if 'train_data_path' in data_args.external_args:
        train_data_path = data_args.external_args['train_data_path']
    else:
        train_data_path = data_configs["tt_vqa"]['train_data_path']
    return TTVqaDataset(train_data_path, data_args, 2.0, data_configs["tt_vqa"])
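

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module): shows the
# annotation shape text_preprocess() consumes and the flat conversation list
# it returns. The record below is hypothetical example data, and __init__ is
# bypassed via __new__ so no annotation file or frames are needed.
if __name__ == '__main__':
    from types import SimpleNamespace

    item = {
        'caption': 'A dog catches a frisbee in a park.',            # hypothetical
        'qas': [{'q': 'What animal is shown?', 'a': 'A dog.'}],     # hypothetical
    }

    ds = TTVqaDataset.__new__(TTVqaDataset)   # skip frame/annotation loading
    ds.data_args = SimpleNamespace()          # no caption_prompt -> default pool
    # Expected shape: [{'from': 'human', 'value': <video token + prompt>},
    #                  {'from': 'model', 'value': ...}, ...], with the video
    #                  token prepended only to the first human turn.
    print(ds.text_preprocess(item))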