"""ShareGPT4V image-caption dataset for LLaVA-style training.

Annotations are loaded from the Hugging Face hub (Lin-Chen/ShareGPT4V); each
sample's relative image path is resolved against local COCO / LLaVA / other
image roots supplied by the data config.
"""
import datasets
import torch
import re          # used only by the commented-out SAM download helper below
import os
import subprocess  # used only by the commented-out SAM download helper below
from llava.datasets.builder import DATASETS

from typing import Dict, List
from llava.datasets.data_cfgs import data_configs
from llava.datasets.base_dataset import ImageTaskDataset
from llava.constants import DEFAULT_IMAGE_TOKEN
from llava.utils import master_print

class ShareGPT4VDataset(ImageTaskDataset):
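    """ShareGPT4V captioning dataset.

    `aux_args` supplies the local image roots (`coco_dir`, `llava_dir`,
    `other_dir`) used to resolve each sample's relative image path.
    """
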
    def __init__(self, anno_path=None, data_args=None, aux_args=None, name='sharegpt4v'):
        super().__init__(anno_path=anno_path,
                         data_args=data_args,
                         name=name)
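        # Annotations come straight from the Hugging Face hub rather than a
        # local annotation file (the registry builder passes anno_path=None).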
        self.annotation = datasets.load_dataset("Lin-Chen/ShareGPT4V", "ShareGPT4V")['train']
        self.aux_args = aux_args
        master_print(f"Finished loading dataset {name}: {len(self.annotation)} samples")

    def __len__(self):
        return len(self.annotation)

    def text_preprocess(self, item) -> List[Dict[str, str]]:
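        # ShareGPT4V samples are single-turn: pair the first human prompt,
        # with the image token prepended, with the first model answer.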
        captions = item['conversations']

        return [
            {
                'from': 'human',
                'value': DEFAULT_IMAGE_TOKEN + captions[0]['value'].replace('<image>', '')
            },
            {
                'from': 'model',
                'value': captions[1]['value']
            }
        ]


    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        item = self.annotation[i]
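        # Resolve the relative image path against the matching local root.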
        if 'coco' in item['image']:
            vis_path = os.path.join(self.aux_args['coco_dir'], item['image'])
        elif 'llava' in item['image']:
            file_names = item['image'].split('/')
            vis_path = os.path.join(self.aux_args['llava_dir'], *file_names[-3:])
        else:
            vis_path = os.path.join(self.aux_args['other_dir'], item['image'])

        ret = {
            'images': self.vis_preprocess(vis_path),
            'conversations': self.text_preprocess(item)
        }
        if 'id' in item:
            ret['id'] = item['id']

        return ret

@DATASETS.register_obj
def ShareGPT4V(data_args):
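    """Dataset builder registered with the DATASETS registry so that training
    configs can select this dataset by name."""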
    data_cfg = data_configs['sharegpt4v']
    return ShareGPT4VDataset(None, data_args, aux_args=data_cfg)
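
# For illustration only: the `sharegpt4v` entry in data_cfgs is assumed to
# provide the three local image roots used above (paths are hypothetical):
#
#   data_configs['sharegpt4v'] = {
#       'coco_dir':  '/data/coco',
#       'llava_dir': '/data/llava',
#       'other_dir': '/data/sharegpt4v',
#   }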

if __name__ == '__main__':
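    # Sanity check: resolve every image path the same way __getitem__ does
    # and print any file that is missing on disk.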
    dataset = datasets.load_dataset("Lin-Chen/ShareGPT4V", "ShareGPT4V")['train']
    aux_args = data_configs['sharegpt4v']
    for item in dataset:
        if 'coco' in item['image']:
            vis_path = os.path.join(aux_args['coco_dir'], item['image'])
        elif 'llava' in item['image']:
            file_names = item['image'].split('/')
            vis_path = os.path.join(aux_args['llava_dir'], *file_names[-3:])
        else:
            vis_path = os.path.join(aux_args['other_dir'], item['image'])
        if not os.path.exists(vis_path):
            print(vis_path)
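
    # One-off helpers kept for reference: download the ShareGPT4V SAM
    # tarballs listed in sam.txt, then extract them into the images directory.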
    # with open('/mnt/bn/yukunfeng-nasdrive/xiangchen/dataset/sharegpt4v/sam.txt') as f:
    #     for line in f:
    #         items = line.split('\t')
    #         name = items[0].strip()
    #         url = items[1].strip()
    #         idx = int(re.search(r'(\d+)', name).group(1))
    #         if idx >= 60:
    #             continue
    #         print(name, url)
    #         output_file = os.path.join('/mnt/bn/yukunfeng-nasdrive/xiangchen/dataset/sharegpt4v/sam', name)
    #         try:
    #             subprocess.run(["wget", "-O", output_file, url], check=True)
    #         except subprocess.CalledProcessError as e:
    #             print(f"An error occurred while downloading {name}: {e}")
    # from glob import glob
    # file_path = '/mnt/bn/yukunfeng-nasdrive/xiangchen/dataset/sharegpt4v/sam'
    # for file_name in glob(os.path.join(file_path, '*.tar')):
    #     subprocess.run(["tar", "-xf", file_name, '-C', '/mnt/bn/yukunfeng-nasdrive/xiangchen/dataset/sharegpt4v/sam/images'], check=True)