
from llava.datasets.builder import DATASETS

from typing import Dict, Optional, Sequence, List
from llava.datasets.data_cfgs import data_configs
from llava.datasets.base_dataset import FramesTaskDataset
import pickle
from pathlib import Path
import random
import numpy as np
from llava.datasets.prompts import tt_caption_prompt, internvid_prompt
from llava.constants import DEFAULT_VIDEO_TOKEN
from PIL import Image
import json
import torch
import os


class PromptV1Dataset(FramesTaskDataset):
    def __init__(self, anno_path=None, data_args=None, name='promptv1_2_internal', task_types=None):
        self.default_fps = 1.0
        self.task_types = task_types
        # get_dataset() reads self.conv_type before the base class is initialised,
        # so pick it up from data_args here when available (otherwise it is assumed
        # to be provided as a class-level default on FramesTaskDataset).
        if data_args is not None and hasattr(data_args, 'conv_type'):
            self.conv_type = data_args.conv_type
        self.annotation = self.get_dataset(anno_path)
        super().__init__(anno_path=anno_path,
                         data_args=data_args,
                         name=name)

    def __len__(self):
        return len(self.annotation)

    def get_dataset(self, anno_path):
        """Load the JSON annotation file and expand it into per-sample entries,
        one entry per task type (and one per QA pair when conv_type is 'single')."""
        dataset = []
        anno_path = Path(anno_path)
        with anno_path.open('r') as f:
            data = json.load(f)
        for info in data:
            for task_type in self.task_types:
                # Skip entries that do not carry this task type.
                if task_type not in info or len(info[task_type]) == 0:
                    continue
                info_task = info.copy()
                if task_type == 'qas' and self.conv_type == 'single':
                    # Single-turn conversations: emit one sample per QA pair.
                    for qa_pair in info_task[task_type]:
                        one_info = info_task.copy()
                        one_info[task_type] = [qa_pair]
                        one_info.update({
                            'task_type': task_type
                        })
                        dataset.append(one_info)
                else:
                    info_task.update({
                        'task_type': task_type
                    })
                    dataset.append(info_task)
        return dataset
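
    # Rough shape of one annotation entry, inferred from get_dataset() and
    # text_preprocess(); the exact field set is an assumption, not a spec:
    #
    # {
    #     "video_path": "<dir of extracted frames>",   # consumed by the base class
    #     "id": "...",                                 # optional sample id
    #     "refine_caption": "a refined caption",       # used by the 'refine_caption' task
    #     "qas": [{"q": "...", "a": "..."}]            # used by the 'qas' task
    # }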


    def text_preprocess(self, item) -> List[Dict[str, str]]:
        """Turn one annotation entry into a flat list of conversation turns."""
        all_convs = []
        if hasattr(self.data_args, 'caption_prompt'):
            # caption_prompt names the prompt list to use,
            # e.g. 'tt_caption_prompt' or 'internvid_prompt'.
            cap_prompt = eval(self.data_args.caption_prompt)
        else:
            cap_prompt = tt_caption_prompt
        if item['task_type'] == 'refine_caption':
            all_convs.append([
                {
                    'from': 'human',
                    'value': random.choice(cap_prompt)
                },
                {
                    'from': 'model',
                    'value': item['refine_caption']
                }
            ])
        else:
            for qa in item['qas']:
                all_convs.append([
                    {
                        'from': 'human',
                        'value': qa['q']
                    },
                    {
                        'from': 'model',
                        'value': qa['a']
                    }
                ])

        # Shuffle the rounds and prepend the video token to the first human turn.
        conversations = []
        random.shuffle(all_convs)
        for idx, conv in enumerate(all_convs):
            if idx == 0:
                conv[0]['value'] = DEFAULT_VIDEO_TOKEN + conv[0]['value']
            conversations.extend(conv)
        return conversations
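
    # Illustrative return value of text_preprocess() for a single-QA item
    # (the values are made up; the structure matches the code above):
    #
    # [
    #     {'from': 'human', 'value': DEFAULT_VIDEO_TOKEN + 'What is the person doing?'},
    #     {'from': 'model', 'value': 'They are cooking pasta.'},
    # ]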



    # def __getitem__(self, i) -> Dict[str, torch.Tensor]:
    #     item = self.annotation[i]

    #     ret = {
    #         'images': self.vis_preprocess(item['video_path']),
    #         'conversations': self.text_preprocess(item)
    #     }
    #     if 'id' in item:
    #         ret['id'] = item['id']

    #     return ret


    # @staticmethod
    # def _sample_frames(frames, num_segments):
    #     indices = np.linspace(start=0, stop=len(frames) - 1, num=num_segments).astype(int)

    #     frames = [frames[ind] for ind in indices]

    #     return frames

    # def vis_preprocess(self, vis_path):
    #     image_files = []
    #     for img_path in os.listdir(vis_path):
    #         if img_path.endswith('.jpeg'):
    #             img_idx = int(img_path.split('_')[-1][:-5])
    #             image_files.append((img_idx, img_path))
        
    #     image_files = sorted(image_files, key=lambda img: img[0])
    #     # TODO: ad-hoc fix, keep only 10 frames
    #     if len(image_files) > 10:
    #         image_files = self._sample_frames(image_files, 10)
    #     if self.num_segments > 0 and len(image_files) > self.num_segments:
    #         image_files = self._sample_frames(image_files, self.num_segments)
        
    #     images = []
    #     for image_file in image_files:
    #         try:
    #             images.append(Image.open(os.path.join(vis_path, image_file[1])).convert('RGB'))
    #         except Exception as e:
    #             continue
    #     formatted_images = []
    #     for image in images:
    #         im = self.preprocess_image(image)
    #         if isinstance(im, list):
    #             formatted_images.extend(im)
    #         else:
    #             formatted_images.append(im)
    #     return formatted_images


@DATASETS.register_obj
def promptv1_2_internal(data_args):
    data_cfg = data_configs['promptv1_2_internal']
    task_types = data_args.external_args['task_types']
    return PromptV1Dataset(anno_path=data_cfg['train_data_path'], data_args=data_args, task_types=task_types)
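

# Hypothetical usage sketch. Only the names defined or imported in this file are
# real; how the DATASETS registry is consumed elsewhere, and the exact contents
# of data_args, are assumptions:
#
#   dataset = promptv1_2_internal(data_args)   # data_args.external_args['task_types'] must be set
#   print(len(dataset))
#   sample = dataset[0]   # keys as in the (commented) __getitem__ above:
#                         # 'images', 'conversations', optionally 'id'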