File size: 1,944 Bytes
bbfa6f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import copy
import os
import re
import subprocess
from typing import Dict, Optional, Sequence, List

import datasets
import torch

from llava.datasets.builder import DATASETS
from llava.datasets.data_cfgs import data_configs
from llava.datasets.base_dataset import ImageTaskDataset
from llava.constants import DEFAULT_IMAGE_TOKEN
from llava.datasets.data_cfgs import data_configs
from llava.utils import master_print

class LKImageDataset(ImageTaskDataset):
    """Image-chat dataset whose annotation entries already contain the dialogue.

    Each annotation entry is read with:

    * an ``'image'`` key, passed to ``vis_preprocess``;
    * a ``'conversations'`` key, returned as the text turns;
    * an optional ``'id'`` key, forwarded when present.

    ``self.annotation`` and ``vis_preprocess`` are not defined in this class.
    Presumably they come from ``ImageTaskDataset`` -- confirm against the base
    class.
    """

    def __init__(self, anno_path=None, data_args=None, aux_args=None, name='lk_image'):
        """Create the dataset.

        Args:
            anno_path: Annotation source, forwarded to ``ImageTaskDataset``.
            data_args: Data arguments, forwarded to ``ImageTaskDataset``.
            aux_args: Extra config. The ``lk_image`` factory passes its whole
                ``data_configs`` entry here. It is accepted for interface
                compatibility but is not forwarded or used.
            name: Dataset name, forwarded to ``ImageTaskDataset``.
        """
        super().__init__(anno_path=anno_path,
                         data_args=data_args,
                         name=name)

    def __len__(self):
        """Return the number of annotation entries."""
        return len(self.annotation)

    def text_preprocess(self, item) -> List[Dict[str, str]]:
        """Return a private copy of the entry's ``'conversations'`` turns.

        ``__getitem__`` reads the same ``self.annotation`` entry every time an
        index is requested. Handing out the stored list itself would let any
        in-place edit made by downstream preprocessing leak into later reads
        of that sample. Returning a deep copy keeps the stored annotation
        unchanged.
        """
        return copy.deepcopy(item['conversations'])

    def __getitem__(self, i) -> Dict[str, object]:
        """Build the sample at index ``i``.

        Returns:
            A dict with these keys:

            * ``'images'``: output of ``vis_preprocess`` on the entry's image
              path;
            * ``'conversations'``: output of ``text_preprocess``;
            * ``'id'``: the entry's id, included only when the entry has one.
        """
        item = self.annotation[i]
        vis_path = item['image']
        ret = {
            'images': self.vis_preprocess(vis_path),
            'conversations': self.text_preprocess(item),
        }
        if 'id' in item:
            ret['id'] = item['id']
        return ret

@DATASETS.register_obj
def lk_image(data_args):
    """Registry factory for the ``lk_image`` dataset.

    Looks up the ``'lk_image'`` entry in ``data_configs`` and uses its
    ``'train_data_path'`` as the annotation path. The whole entry is passed
    along as ``aux_args``.

    Args:
        data_args: Data arguments forwarded to ``LKImageDataset``.

    Returns:
        A newly constructed ``LKImageDataset``.
    """
    cfg = data_configs['lk_image']
    train_path = cfg['train_data_path']
    return LKImageDataset(anno_path=train_path, data_args=data_args, aux_args=cfg)

# One-off maintenance script (kept disabled): drop annotation entries whose
# image file does not exist on disk, printing each missing path, then rewrite
# the JSON annotation file in place with only the surviving entries.
# if __name__ == '__main__':
    # import json
    # from tqdm import tqdm
    # with open('/mnt/bn/liangkeg/data/xiangchen/finetune_all_detail_vidal200k_videollava_images_im.json') as f:
    #     data = json.load(f)
    # filtered_data = []
    # for idx, item in tqdm(enumerate(data)):
    #     image_path = item['image']
    #     if os.path.exists(image_path):
    #         filtered_data.append(item)
    #     else:
    #         print(image_path)
    # with open('/mnt/bn/liangkeg/data/xiangchen/finetune_all_detail_vidal200k_videollava_images_im.json', 'w') as f:
    #     json.dump(filtered_data, f)