import os
import json
import sys
import math
import torch
import decord
import random
import numpy as np
from PIL import Image
from decord import VideoReader
from torch.utils.data import Dataset
from llava.utils import master_print
from typing import Dict, List
from llava.datasets.data_cfgs import data_configs
from transformers import CLIPImageProcessor, SiglipImageProcessor

from llava.mm_utils import get_frame_indices, process_anyres_image
from torch.utils.data.dataloader import default_collate

decord.bridge.set_bridge("torch")  # VideoReader.get_batch returns torch tensors

class TaskBaseDataset(Dataset):
    """ Implementation of base task dataset """
    def __init__(self, anno_path=None, data_args=None, name=None, **kwargs):

        self.anno_path = anno_path
        self.data_args = data_args
        self.image_aspect_ratio = data_args.image_aspect_ratio
        self.image_grid_pinpoints = data_args.image_grid_pinpoints
        self.vis_processor = data_args.image_processor
        self.type = None
        self.name = name

        master_print(f"Loading dataset {name}...")
        if anno_path is not None:
            if not hasattr(self, 'annotation'):
                with open(anno_path, 'r') as f:
                    self.annotation = json.load(f)
            master_print(f"Finished loading dataset {name}: {len(self.annotation)} samples")

    def __len__(self):
        return len(self.annotation)

    def collater(self, samples):
        return default_collate(samples)

    def text_preprocess(self, sources) -> List[List[Dict[str, str]]]:
        # subclasses turn a raw annotation item into conversation turns
        raise NotImplementedError

    def vis_preprocess(self, vis_path) -> List:
        # subclasses load and preprocess the visual input(s)
        raise NotImplementedError

    @property
    def data_type(self):
        return self.type

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        item = self.annotation[i]

        vis_path = item['vis_path'] if 'vis_path' in item else item['video_path']

        ret = {
            'images': self.vis_preprocess(vis_path),
            'conversations': self.text_preprocess(item)
        }
        if 'id' in item:
            ret['id'] = item['id']

        return ret


class ImageTaskDataset(TaskBaseDataset):
    def __init__(self, anno_path=None, data_args=None, name=None):
        super().__init__(anno_path=anno_path,
                         data_args=data_args,
                         name=name)
        self.type = 'images'

    @staticmethod
    def expand2square(pil_img, background_color):
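        """Pad a PIL image to a square canvas filled with `background_color`,
        keeping the original content centered (e.g. 640x480 -> 640x640)."""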
        width, height = pil_img.size
        if width == height:
            return pil_img
        elif width > height:
            result = Image.new(pil_img.mode, (width, width), background_color)
            result.paste(pil_img, (0, (width - height) // 2))
            return result
        else:
            result = Image.new(pil_img.mode, (height, height), background_color)
            result.paste(pil_img, ((height - width) // 2, 0))
            return result

    def preprocess_image(self, image):
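        """Apply the visual processor according to `image_aspect_ratio`:
        'pad' squares the image with the processor's mean color first,
        'anyres' tiles it via process_anyres_image, and anything else
        is processed as-is."""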
        if self.image_aspect_ratio == 'pad':
            image = self.expand2square(image, tuple(int(x * 255) for x in self.vis_processor.image_mean))
        elif self.image_aspect_ratio == 'anyres':
            return process_anyres_image(image, self.vis_processor, self.image_grid_pinpoints)

        if isinstance(self.vis_processor, (CLIPImageProcessor, SiglipImageProcessor)):
            image = self.vis_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
        else:
            image = self.vis_processor.preprocess(image)

        return image

    def vis_preprocess(self, vis_path):
        image = Image.open(vis_path).convert('RGB')
        image = self.preprocess_image(image)
        # anyres preprocessing may return a list of tiles; normalize to a list
        return image if isinstance(image, list) else [image]


class VideoTaskDataset(ImageTaskDataset):
    def __init__(self, anno_path=None, data_args=None, name=None):
        super().__init__(anno_path=anno_path,
                         data_args=data_args,
                         name=name)

        # frame count and sampling strategy come from data_args
        self.num_segments = self.data_args.num_segments
        self.sample_strategy = self.data_args.sample_strategy
        self.type = 'video'

    def vis_preprocess(self, vis_path):
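        """Decode `num_segments` frames from the video at `vis_path` and
        return them as a list of preprocessed images; returns None if the
        video cannot be read."""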
        try:
            video_reader = VideoReader(vis_path)
            vlen = len(video_reader)
            fps = video_reader.get_avg_fps()

            frame_indices = get_frame_indices(self.num_segments, vlen,
                                              sample=self.sample_strategy, input_fps=fps, pad_last=False)
            frames = video_reader.get_batch(frame_indices)  # (num_frames, H, W, C) torch tensor
            frames = frames.numpy().astype(np.uint8)
            images = [Image.fromarray(frame).convert('RGB') for frame in frames]
            images = [self.preprocess_image(image) for image in images]
        except Exception as e:
            print(f"Failed to load video {vis_path}: {e}")
            sys.stdout.flush()
            images = None

        return images


class FramesTaskDataset(ImageTaskDataset):
    def __init__(self, anno_path=None, data_args=None, fps=0.5, name=None):
        super().__init__(anno_path=anno_path,
                         data_args=data_args,
                         name=name)

        # frame count comes from data_args; fps controls temporal downsampling
        self.num_segments = self.data_args.num_segments
        self.type = 'video'
        self.default_fps = 2.0  # rate the on-disk frames are assumed to be extracted at
        self.fps = fps          # target sampling rate; lower values trigger downsampling

    @staticmethod
    def _downsample_frames(frames, interval, keep_first_last=True):
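        """Keep every `interval`-th frame. With `keep_first_last=True` the
        first and last frames are always retained, e.g. 10 frames at
        interval=4 yield indices [0, 4, 8, 9]."""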
        if keep_first_last:
            first, last, mid = frames[0], frames[-1], frames[1:-1]
            sampled_frames = mid[interval - 1::interval]
            ret = [first] + sampled_frames + [last]
        else:
            # may return an empty list; keeping the first and last frames is safer
            ret = frames[interval - 1::interval]

        return ret

    @staticmethod
    def _sample_frames(frames, num_segments):
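        """Split `frames` into `num_segments` roughly equal bins and draw one
        random frame from each, e.g. 10 frames with num_segments=4 sample
        from bins (0,1), (2,4), (5,6), (7,9)."""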
        intervals = np.linspace(start=0, stop=len(frames), num=num_segments + 1).astype(int)
        ranges = [(interv, intervals[idx + 1] - 1) for idx, interv in enumerate(intervals[:-1])]

        try:
            frame_indices = [random.choice(range(start, end)) for start, end in ranges]
        except IndexError:
            # a bin can be empty when there are fewer frames than segments; fall back to bin starts
            frame_indices = [start for start, _ in ranges]

        return [frames[idx] for idx in frame_indices]

    def vis_preprocess(self, vis_path):
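        """Load pre-extracted frames from the directory `vis_path`, downsample
        them towards `self.fps`, cap the count at `num_segments`, and return
        the preprocessed images."""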
        image_files = [(os.path.splitext(img)[0], img) for img in os.listdir(vis_path) if not img.startswith('cuttime')]
        if image_files[0][1].endswith('jpeg'):
            # gpt4v public data
            image_files = [(int(x[0].split('_')[-1]), x[1]) for x in image_files]
        else:
            image_files = [(int(x[0]), x[1]) for x in image_files]

        image_files = sorted(image_files, key=lambda img: img[0])

        if self.fps < self.default_fps:
            interval = math.floor(self.default_fps / self.fps)
            image_files = self._downsample_frames(image_files, interval, keep_first_last=True)

        if self.num_segments > 0 and len(image_files) > self.num_segments:
            image_files = self._sample_frames(image_files, self.num_segments)

        images = []
        for image_file in image_files:
            try:
                images.append(Image.open(os.path.join(vis_path, image_file[1])).convert('RGB'))
            except Exception as e:
                print(f"Failed to open frame {image_file[1]} in {vis_path}: {e}")
                continue
        formatted_images = []
        for image in images:
            im = self.preprocess_image(image)
            if isinstance(im, list):
                formatted_images.extend(im)
            else:
                formatted_images.append(im)
        return formatted_images
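

if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch, not part of the training
    # pipeline): it only exercises the pure static helpers, so no annotation
    # files, processors, or data_args are required.
    img = Image.new('RGB', (640, 480), (255, 0, 0))
    padded = ImageTaskDataset.expand2square(img, (0, 0, 0))
    print(padded.size)  # (640, 640)

    frames = list(range(10))
    print(FramesTaskDataset._downsample_frames(frames, interval=4))  # [0, 4, 8, 9]
    print(FramesTaskDataset._sample_frames(frames, num_segments=4))  # e.g. [0, 3, 5, 8]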