import torch
import json
import random
from pathlib import Path
from typing import Dict, List

from llava.datasets.builder import DATASETS
from llava.datasets.data_cfgs import data_configs
from llava.datasets.base_dataset import ImageTaskDataset
from llava.datasets.prompts import cc_sbu_prompt
from llava.constants import DEFAULT_IMAGE_TOKEN


class TextCapsDataset(ImageTaskDataset):
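    """Image-captioning dataset over TextCaps-style annotations.

    The annotation file is expected to be a JSON object with a top-level
    'data' list whose items carry at least 'image_path' and 'caption_str'
    (and optionally 'id'); image paths are resolved relative to the
    directory containing the annotation file.
    """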
    def __init__(self, anno_path=None, data_args=None, aux_args=None, name='TextCaps'):
        with open(anno_path) as f:
            self.annotation = json.load(f)['data']        
        self.dataset_dir = Path(anno_path).parent
        super().__init__(anno_path=anno_path,
                         data_args=data_args,
                         name=name)
        
    def __len__(self):
        return len(self.annotation)
    

    def text_preprocess(self, item) -> List[Dict[str, str]]:
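        """Build a single-turn conversation: a randomly chosen captioning
        prompt (prefixed with the image token) on the human side, answered
        by the ground-truth 'caption_str' on the model side."""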
        conversations = [
            {
                'from': 'human',
                'value': DEFAULT_IMAGE_TOKEN + random.choice(cc_sbu_prompt)
            },
            {
                'from': 'model',
                'value': item['caption_str']
            }
        ]

        return conversations


    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
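        """Return one training sample: preprocessed image(s) plus the paired
        conversation; 'id' is passed through when present in the annotation."""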
        item = self.annotation[i]
        vis_path = self.dataset_dir / item['image_path']
        ret = {
            'images': self.vis_preprocess(str(vis_path)),
            'conversations': self.text_preprocess(item)
        }
        if 'id' in item:
            ret['id'] = item['id']

        return ret

@DATASETS.register_obj
def TextCaps(data_args):
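    """Dataset builder registered with the DATASETS registry; the training
    annotation path is read from data_configs['text_caps']."""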
    data_cfg = data_configs['text_caps']
    return TextCapsDataset(data_cfg['train_data_path'], data_args)

if __name__ == '__main__':
    # Ad-hoc sanity check on the raw annotation file: confirm every record
    # carries the fields the dataset class reads ('image_path', 'caption_str').
    with open('/mnt/bn/yukunfeng-nasdrive/xiangchen/dataset/TextCaps/TextCaps_0.1_train.json') as f:
        data = json.load(f)['data']
    res = []
    for value in data:
        if not value.get('caption_str') or not value.get('image_path'):
            print(f"incomplete record: {value.get('id', value)}")
            continue
        res.append(value)
    print(f'{len(res)} of {len(data)} records have the expected fields')