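# LKVideoDataset: frame-based video dataset loader for LLaVA-style training.
# Reads a JSON annotation file, optionally filters samples by source dataset
# (inferred from the parent directory of each sample's 'video' path), and
# yields preprocessed frame images plus multi-turn conversations.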

import json
import os
from pathlib import Path
from typing import Dict, List

import numpy as np
import torch
from PIL import Image

from llava.datasets.builder import DATASETS
from llava.datasets.base_dataset import FramesTaskDataset
from llava.datasets.data_cfgs import data_configs


class LKVideoDataset(FramesTaskDataset):
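    """Dataset of videos stored as folders of extracted '.jpeg' frames.

    Each annotation entry is expected to provide a 'video' key (path to a
    frame folder) and a 'conversations' key (multi-turn dialogue); see
    vis_preprocess and text_preprocess below.
    """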
    def __init__(self, anno_path=None, data_args=None, fps=1.0, conv_type='multi', select_datasets=None, name='lk_video'):
        self.default_fps = 1.0
        self.fps = fps
        self.conv_type = conv_type
        self.select_datasets = select_datasets
        self.annotation = self.get_dataset(anno_path)
        # TODO: support the 'single' conversation type.
        # Note: membership must be tested against a tuple; `in ('multi')`
        # would be a substring check against the string 'multi'.
        assert self.conv_type in ('multi',), "lk_video conv type must be multi"
        super().__init__(anno_path=anno_path,
                         data_args=data_args,
                         fps=fps,
                         name=name)

    def __len__(self):
        return len(self.annotation)


    def get_dataset(self, anno_path):
        anno_path = Path(anno_path)
        # Annotations are stored as a JSON list of samples.
        with anno_path.open('r') as f:
            data = json.load(f)

        if self.select_datasets is not None:
            # Keep only samples whose source dataset (the parent directory
            # name of the frame folder) is in the requested subset.
            filtered_data = []
            for sample in data:
                video_path = Path(sample['video'])
                dataset_name = video_path.parent.name
                if dataset_name in self.select_datasets:
                    filtered_data.append(sample)
            data = filtered_data

        return data


    def text_preprocess(self, item) -> List[Dict[str, str]]:
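        # Conversations are stored pre-formatted in the annotation file.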
        return item['conversations']


    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
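        # Assemble one training sample: preprocessed frames plus the
        # conversation turns; 'id' is passed through when present.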
        item = self.annotation[i]

        ret = {
            'images': self.vis_preprocess(item['video']),
            'conversations': self.text_preprocess(item)
        }
        if 'id' in item:
            ret['id'] = item['id']

        return ret


    @staticmethod
    def _sample_frames(frames, num_segments):
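        # Uniformly sample `num_segments` entries across the frame list,
        # including both endpoints; indices may repeat when num_segments
        # exceeds len(frames).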
        indices = np.linspace(start=0, stop=len(frames) - 1, num=num_segments).astype(int)

        frames = [frames[ind] for ind in indices]

        return frames

    def vis_preprocess(self, vis_path):
        # Collect frame files; names are expected to end in '_<index>.jpeg',
        # so the frame index is parsed from the last underscore-separated
        # field with the '.jpeg' suffix stripped.
        image_files = []
        for img_path in os.listdir(vis_path):
            if img_path.endswith('.jpeg'):
                img_idx = int(img_path.split('_')[-1][:-5])
                image_files.append((img_idx, img_path))

        image_files = sorted(image_files, key=lambda img: img[0])
        # TODO: ad hoc fix, cap at 10 frames.
        if len(image_files) > 10:
            image_files = self._sample_frames(image_files, 10)
        if self.num_segments > 0 and len(image_files) > self.num_segments:
            image_files = self._sample_frames(image_files, self.num_segments)

        images = []
        for image_file in image_files:
            try:
                images.append(Image.open(os.path.join(vis_path, image_file[1])).convert('RGB'))
            except Exception:
                # Skip unreadable or corrupt frames rather than failing the sample.
                continue
        formatted_images = []
        for image in images:
            im = self.preprocess_image(image)
            if isinstance(im, list):
                formatted_images.extend(im)
            else:
                formatted_images.append(im)
        return formatted_images


@DATASETS.register_obj
def lk_video(data_args):
    data_cfg = data_configs['lk_video']
    fps, conv_type = data_args.external_args['fps'], data_args.external_args['conv_type']
    select_datasets = data_args.external_args.get('select_datasets')
    return LKVideoDataset(data_cfg['train_data_path'], data_args, fps, conv_type, select_datasets=select_datasets)
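
# Minimal usage sketch (hypothetical values; the 'fps', 'conv_type', and
# optional 'select_datasets' keys are read from data_args.external_args
# by the registry function above):
#
#   data_args.external_args = {'fps': 1.0, 'conv_type': 'multi'}
#   dataset = lk_video(data_args)
#   sample = dataset[0]  # -> {'images': [...], 'conversations': [...]}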


# One-off script (kept for reference) that dropped annotation entries whose
# frame directories were missing on disk:
# if __name__ == '__main__':
    # import json
    # from tqdm import tqdm
    # with open('/mnt/bn/liangkeg/data/xiangchen/finetune_all_detail_vidal200k_videollava_images_vid.json') as f:
    #     data = json.load(f)
    # filtered_data = []
    # for item in tqdm(data):
    #     image_path = item['video']
    #     if os.path.exists(image_path):
    #         filtered_data.append(item)
    #     else:
    #         print(image_path)
    # with open('/mnt/bn/liangkeg/data/xiangchen/finetune_all_detail_vidal200k_videollava_images_vid.json', 'w') as f:
    #     json.dump(filtered_data, f)