# Source: model1/llava/datasets/textcaps_dataset.py
# Uploaded by multitensor using huggingface_hub (commit bbfa6f6, verified).
import datasets
import torch
import re
import os
import json
from llava.datasets.builder import DATASETS
from pathlib import Path
import random
from typing import Dict, Optional, Sequence, List
from llava.datasets.data_cfgs import data_configs
from llava.datasets.base_dataset import ImageTaskDataset
from llava.datasets.prompts import cc_sbu_prompt
from llava.constants import DEFAULT_IMAGE_TOKEN
from llava.datasets.data_cfgs import data_configs
from llava.utils import master_print
class TextCapsDataset(ImageTaskDataset):
    """Image-captioning dataset backed by a TextCaps-style JSON annotation file.

    The annotation file is read as a JSON object whose ``'data'`` key holds a
    list of records. Each record is read for ``'image_path'`` (resolved
    relative to the annotation file's directory), ``'caption_str'`` and,
    optionally, ``'id'``.
    """

    def __init__(self, anno_path=None, data_args=None, aux_args=None, name='TextCaps'):
        """Load the annotations from ``anno_path`` and initialise the base dataset.

        Args:
            anno_path: Path to the annotation JSON file. It is opened
                unconditionally, so a real path must be supplied.
            data_args: Forwarded to ``ImageTaskDataset.__init__``.
            aux_args: Not used by this class.
            name: Dataset name forwarded to the base class.
        """
        # JSON is UTF-8 by spec; don't rely on the platform's locale encoding,
        # since captions may contain non-ASCII text.
        with open(anno_path, encoding='utf-8') as f:
            self.annotation = json.load(f)['data']
        # Image paths in the annotations are resolved relative to this directory.
        self.dataset_dir = Path(anno_path).parent
        super().__init__(anno_path=anno_path,
                         data_args=data_args,
                         name=name)

    def __len__(self):
        """Return the number of annotation records."""
        return len(self.annotation)

    def text_preprocess(self, item) -> List[Dict[str, str]]:
        """Build a two-turn conversation for one annotation record.

        The human turn is the image token followed by a prompt chosen at
        random from ``cc_sbu_prompt``. The model turn is the record's
        ``'caption_str'``.
        """
        return [
            {
                'from': 'human',
                'value': DEFAULT_IMAGE_TOKEN + random.choice(cc_sbu_prompt),
            },
            {
                'from': 'model',
                'value': item['caption_str'],
            },
        ]

    def __getitem__(self, i) -> Dict[str, object]:
        """Return the sample at index ``i``.

        Returns:
            A dict with these entries:

            * ``'images'``: the result of ``self.vis_preprocess`` (inherited
              from ``ImageTaskDataset``) applied to the image path.
            * ``'conversations'``: see ``text_preprocess``.
            * ``'id'``: copied from the record only when the record has one.
        """
        item = self.annotation[i]
        vis_path = self.dataset_dir / item['image_path']
        ret = {
            'images': self.vis_preprocess(str(vis_path)),
            'conversations': self.text_preprocess(item),
        }
        if 'id' in item:
            ret['id'] = item['id']
        return ret
@DATASETS.register_obj
def TextCaps(data_args):
    """Build the TextCaps training dataset.

    This function is registered through ``DATASETS.register_obj``. It reads the
    training annotation path from ``data_configs['text_caps']`` and passes
    ``data_args`` through to the dataset.
    """
    train_anno_path = data_configs['text_caps']['train_data_path']
    return TextCapsDataset(anno_path=train_anno_path, data_args=data_args)
if __name__ == '__main__':
    # Ad-hoc sanity check: report annotation entries whose 'questions' list is
    # empty. An annotation path may be given as the first CLI argument to
    # override the hard-coded default below.
    import sys

    default_path = '/mnt/bn/yukunfeng-nasdrive/xiangchen/dataset/TextCaps/TextCaps_0.1_train.json'
    check_path = sys.argv[1] if len(sys.argv) > 1 else default_path
    with open(check_path, encoding='utf-8') as f:
        data = json.load(f)
    # NOTE(review): this treats the file as a top-level list of dicts with a
    # 'questions' key, whereas TextCapsDataset reads json.load(f)['data'] and
    # uses 'caption_str'. Confirm this file's layout before relying on it.
    empty_indices = [idx for idx, value in enumerate(data) if not value['questions']]
    for idx in empty_indices:
        print(f'entry {idx} has no questions')
    print(f'{len(empty_indices)} / {len(data)} entries have no questions')