File size: 2,494 Bytes
bbfa6f6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 |
import datasets
import torch
import re
import os
import json
from llava.datasets.builder import DATASETS
from pathlib import Path
import random
from typing import Dict, Optional, Sequence, List
from llava.datasets.data_cfgs import data_configs
from llava.datasets.base_dataset import ImageTaskDataset
from llava.datasets.prompts import cc_sbu_prompt
from llava.constants import DEFAULT_IMAGE_TOKEN
from llava.datasets.data_cfgs import data_configs
from llava.utils import master_print
class TextCapsDataset(ImageTaskDataset):
    """Image-captioning dataset backed by a JSON annotation file.

    The annotation file must be a JSON object whose ``'data'`` entry holds
    the records. Each record is read for ``'image_path'`` (resolved against
    the directory containing the annotation file) and ``'caption_str'``
    (used as the model reply); an ``'id'`` entry is passed through when
    present.
    """

    def __init__(self, anno_path=None, data_args=None, aux_args=None, name='TextCaps'):
        """Load the annotation records, then initialise the base dataset.

        Args:
            anno_path: Path of the JSON annotation file. Despite the ``None``
                default, a real path is required because the file is opened
                immediately.
            data_args: Forwarded unchanged to ``ImageTaskDataset.__init__``.
            aux_args: Accepted but not used by this class.
            name: Dataset name forwarded to ``ImageTaskDataset.__init__``.
        """
        # JSON text is UTF-8 by specification; being explicit keeps loading
        # independent of the platform's default locale encoding.
        with open(anno_path, encoding='utf-8') as f:
            self.annotation = json.load(f)['data']
        # Image paths inside the records are relative to this directory.
        self.dataset_dir = Path(anno_path).parent
        super().__init__(anno_path=anno_path,
                         data_args=data_args,
                         name=name)

    def __len__(self):
        """Return the number of annotation records."""
        return len(self.annotation)

    def text_preprocess(self, item) -> List[Dict[str, str]]:
        """Build the two-turn conversation for one annotation record.

        Args:
            item: Annotation record; must contain ``'caption_str'``.

        Returns:
            A list with a ``'human'`` turn (image token followed by a prompt
            drawn at random from ``cc_sbu_prompt``) and a ``'model'`` turn
            carrying the record's caption.
        """
        prompt = DEFAULT_IMAGE_TOKEN + random.choice(cc_sbu_prompt)
        return [
            {
                'from': 'human',
                'value': prompt
            },
            {
                'from': 'model',
                'value': item['caption_str']
            },
        ]

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        """Return the sample at index ``i``.

        The result maps ``'images'`` to the output of ``self.vis_preprocess``
        (not defined in this class; presumably inherited from
        ``ImageTaskDataset``) and ``'conversations'`` to the output of
        ``text_preprocess``. ``'id'`` is included only when the record has one.
        """
        item = self.annotation[i]
        vis_path = self.dataset_dir / item['image_path']
        ret = {
            'images': self.vis_preprocess(str(vis_path)),
            'conversations': self.text_preprocess(item)
        }
        if 'id' in item:
            ret['id'] = item['id']
        return ret
@DATASETS.register_obj
def TextCaps(data_args):
    """Create the TextCaps training dataset from the ``'text_caps'`` config.

    Args:
        data_args: Passed through to ``TextCapsDataset`` as its ``data_args``.

    Returns:
        A ``TextCapsDataset`` reading the annotation file named by the
        config entry's ``'train_data_path'``.
    """
    train_anno_path = data_configs['text_caps']['train_data_path']
    return TextCapsDataset(train_anno_path, data_args)
if __name__ == '__main__':
    # Ad-hoc sanity check over the raw annotation file: prints a marker for
    # every record whose caption is missing or empty.
    with open('/mnt/bn/yukunfeng-nasdrive/xiangchen/dataset/TextCaps/TextCaps_0.1_train.json',
              encoding='utf-8') as f:
        # Records sit under the top-level 'data' key, the same way
        # TextCapsDataset.__init__ reads them. Iterating the decoded JSON
        # object itself would yield its string keys, not the records.
        data = json.load(f)['data']
    res = []
    for value in data:
        # 'caption_str' is the field TextCapsDataset.text_preprocess consumes;
        # .get() also covers records that lack the key entirely.
        if not value.get('caption_str'):
            print(1)
        res.append(value)
|