File size: 3,995 Bytes
bbfa6f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import os
import json
import random
import json
from pathlib import Path
from llava.datasets.builder import DATASETS
from pathlib import Path
from typing import Dict, Optional, Sequence, List
from llava.datasets.data_cfgs import data_configs
from llava.datasets.base_dataset import FramesTaskDataset
from llava.datasets.prompts import tt_caption_prompt, tt_caption_prompt2
from llava.constants import DEFAULT_VIDEO_TOKEN
from llava.utils import master_print


class GPT4VTTVqaDataset(FramesTaskDataset):
    """Frames-based video dataset over GPT-4V-annotated caption/QA data.

    Each annotation record may carry a caption and/or a list of QA pairs
    (key ``'qas'``). ``get_dataset`` expands records into one sample per
    task type; with ``conv_type='single'`` every QA pair additionally
    becomes its own sample.
    """

    def __init__(self, anno_path, data_args=None, fps=0.5, conv_type='single', task_types=None, name='gpt4v_tt_vqa'):
        """
        Args:
            anno_path: path to the JSON annotation file.
            data_args: dataset-wide config object (may expose
                ``caption_prompt``; forwarded to the base class).
            fps: frame sampling rate passed to the base class.
            conv_type: 'single' (one QA pair per sample) or 'multi'
                (all QA pairs kept in one conversation).
            task_types: iterable of annotation keys to expand into
                samples (e.g. 'caption', 'qas').
            name: registry name used by the base class.
        """
        # Validate conv_type BEFORE it is consumed by get_dataset() below.
        assert conv_type in ('single', 'multi'), "gpt4v_tt_vqa conv type must in single/multi"
        self.default_fps = 0.5
        self.fps = fps
        self.conv_type = conv_type
        self.task_types = task_types
        self.annotation = self.get_dataset(anno_path)
        super().__init__(anno_path=anno_path,
                         data_args=data_args,
                         fps=fps,
                         name=name)

    def get_dataset(self, anno_path):
        """Load the JSON annotations and expand them into flat samples.

        For every record and every configured task type, emits a copy of
        the record tagged with ``'task_type'``. QA records are exploded
        into one sample per QA pair when ``conv_type == 'single'``.

        Returns:
            list[dict]: flattened sample dicts.
        """
        dataset = []
        anno_path = Path(anno_path)
        with anno_path.open('rb') as f:
            data = json.load(f)
        for info in data:
            for task_type in self.task_types:
                # Skip records that lack this task or have it empty
                # (check before copying to avoid a wasted dict copy).
                if task_type not in info or len(info[task_type]) == 0:
                    continue
                if task_type == 'qas' and self.conv_type == 'single':
                    # One training sample per QA pair.
                    for qa_pair in info[task_type]:
                        one_info = info.copy()
                        one_info[task_type] = [qa_pair]
                        one_info['task_type'] = task_type
                        dataset.append(one_info)
                else:
                    info_task = info.copy()
                    info_task['task_type'] = task_type
                    dataset.append(info_task)
        return dataset

    def text_preprocess(self, item) -> List[Dict[str, str]]:
        """Build the conversation turns for one sample.

        Returns a flat list of {'from', 'value'} dicts; the video token
        is prepended to the first human turn after shuffling.
        """
        all_convs = []
        if hasattr(self.data_args, 'caption_prompt'):
            # SECURITY: eval() executes arbitrary text from the config.
            # Acceptable only for trusted, locally-authored YAML; do not
            # feed untrusted config through this path.
            cap_prompt = eval(self.data_args.caption_prompt)
        else:
            cap_prompt = tt_caption_prompt
        if item['task_type'] == 'caption':
            all_convs.append([
                {
                    'from': 'human',
                    'value': random.choice(cap_prompt)
                },
                {
                    'from': 'model',
                    'value': item['caption']
                }
            ])
        else:
            for qa in item['qas']:
                all_convs.append([
                    {
                        'from': 'human',
                        'value': qa['q']
                    },
                    {
                        'from': 'model',
                        'value': qa['a']
                    }
                ])

        conversations = []
        random.shuffle(all_convs)
        for idx, conv in enumerate(all_convs):
            if idx == 0:
                # Only the very first human turn carries the video token.
                conv[0]['value'] = DEFAULT_VIDEO_TOKEN + conv[0]['value']
            conversations.extend(conv)
        return conversations

    

@DATASETS.register_obj
def gpt4v_tt_vqa(data_args):
    """Registry builder for the gpt4v_tt_vqa dataset.

    Resolves the annotation path from ``external_args`` when provided,
    otherwise from the global data config, then constructs the dataset.
    """
    ext = data_args.external_args
    if 'train_data_path' in ext:
        anno_path = ext['train_data_path']
    else:
        anno_path = data_configs["gpt4v_tt_vqa"]['train_data_path']
    fps = ext['fps']
    conv_type = ext['conv_type']
    task_types = ext['task_types']
    return GPT4VTTVqaDataset(anno_path, data_args, fps, conv_type, task_types)