import torch
import json
from llava.datasets.builder import DATASETS
from pathlib import Path
import random
from typing import Dict, List
from llava.datasets.data_cfgs import data_configs
from llava.datasets.base_dataset import ImageTaskDataset
from llava.datasets.prompts import cc_sbu_prompt
from llava.constants import DEFAULT_IMAGE_TOKEN


class TextCapsDataset(ImageTaskDataset):
    """Image-captioning dataset built from a TextCaps annotation JSON."""

    def __init__(self, anno_path=None, data_args=None, aux_args=None, name='TextCaps'):
        with open(anno_path) as f:
            self.annotation = json.load(f)['data']
        # Image paths in the annotation file are relative to its directory.
        self.dataset_dir = Path(anno_path).parent

        super().__init__(anno_path=anno_path, data_args=data_args, name=name)

    def __len__(self):
        return len(self.annotation)

    def text_preprocess(self, item) -> List[Dict[str, str]]:
        # Pair a randomly sampled captioning prompt with the ground-truth caption.
        conversations = [
            {
                'from': 'human',
                'value': DEFAULT_IMAGE_TOKEN + random.choice(cc_sbu_prompt)
            },
            {
                'from': 'model',
                'value': item['caption_str']
            }
        ]
        return conversations

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        item = self.annotation[i]
        vis_path = self.dataset_dir / item['image_path']
        ret = {
            'images': self.vis_preprocess(str(vis_path)),
            'conversations': self.text_preprocess(item)
        }
        if 'id' in item:
            ret['id'] = item['id']
        return ret


@DATASETS.register_obj
def TextCaps(data_args):
    data_cfg = data_configs['text_caps']
    return TextCapsDataset(data_cfg['train_data_path'], data_args)


if __name__ == '__main__':
    # Ad-hoc sanity check: scan a local annotation dump and flag entries
    # whose question list is empty.
    with open('/mnt/bn/yukunfeng-nasdrive/xiangchen/dataset/TextCaps/TextCaps_0.1_train.json') as f:
        data = json.load(f)
    res = []
    for value in data:
        if len(value['questions']) == 0:
            # Debug marker for entries with no questions.
            print(1)
        res.append(value)
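# --- Usage sketch (illustrative only) ---
# A minimal sketch of constructing the dataset directly, assuming a local
# annotation JSON and a populated `data_args`; the path and `data_args`
# below are placeholders, not values shipped with this repo. The registered
# `TextCaps` builder above is the route used by the training pipeline.
#
#   dataset = TextCapsDataset(anno_path='path/to/TextCaps_0.1_train.json',
#                             data_args=data_args)
#   sample = dataset[0]  # {'images': ..., 'conversations': [...], 'id': ...}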