import random
from typing import Dict, List

from llava.constants import DEFAULT_VIDEO_TOKEN
from llava.datasets.base_dataset import FramesTaskDataset
from llava.datasets.builder import DATASETS
from llava.datasets.data_cfgs import data_configs
# tt_caption_prompt2 must remain importable: data_args.caption_prompt may name
# it, and eval() in text_preprocess resolves that name in this module's scope.
from llava.datasets.prompts import tt_caption_prompt, tt_caption_prompt2


class TTVqaDataset(FramesTaskDataset):
    def __init__(self, anno_path, data_args=None, fps=2.0, data_cfgs=None, name='tt_vqa'):
        super().__init__(anno_path=anno_path,
                         data_args=data_args,
                         fps=fps,
                         name=name)
        # Guard the data_cfgs=None default: fall back to the explicit fps
        # argument instead of crashing on a None subscript.
        self.default_fps = data_cfgs['fps'] if data_cfgs else fps

    def text_preprocess(self, item) -> List[Dict[str, str]]:
        all_convs = []
        # data_args.caption_prompt, when set, names a prompt list defined in
        # llava.datasets.prompts (e.g. 'tt_caption_prompt2'); eval() resolves
        # the name. getattr with a default also covers the attribute being None.
        if getattr(self.data_args, 'caption_prompt', None):
            cap_prompt = eval(self.data_args.caption_prompt)
        else:
            cap_prompt = tt_caption_prompt
        if 'caption' in item:
            all_convs.append([
                {
                    'from': 'human',
                    'value': random.choice(cap_prompt)
                },
                {
                    'from': 'model',
                    'value': item['caption']
                }
            ])
        for qa in item.get('qas', []):
            all_convs.append([
                {
                    'from': 'human',
                    'value': qa['q']
                },
                {
                    'from': 'model',
                    'value': qa['a']
                }
            ])

        # Shuffle the caption/QA turn pairs, then prepend the video token to
        # the first human message so the frames are injected exactly once.
        conversations = []
        random.shuffle(all_convs)
        for idx, conv in enumerate(all_convs):
            if idx == 0:
                conv[0]['value'] = DEFAULT_VIDEO_TOKEN + conv[0]['value']
            conversations.extend(conv)

        return conversations


@DATASETS.register_obj
def tt_vqa(data_args):
    # Prefer an explicitly supplied path; otherwise fall back to the
    # registered dataset config.
    if 'train_data_path' in data_args.external_args:
        train_data_path = data_args.external_args['train_data_path']
    else:
        train_data_path = data_configs['tt_vqa']['train_data_path']
    return TTVqaDataset(anno_path=train_data_path,
                        data_args=data_args,
                        fps=2.0,
                        data_cfgs=data_configs['tt_vqa'])