# model1/llava/datasets/synthetic_ocr_dataset.py
import json
import random
from typing import Dict, List

import torch

from llava.constants import DEFAULT_VIDEO_TOKEN
from llava.datasets.base_dataset import FramesTaskDataset
from llava.datasets.builder import DATASETS
from llava.datasets.data_cfgs import data_configs
from llava.datasets.prompts import tt_caption_prompt, ocr_prompt


class SyntheticOCRDataset(FramesTaskDataset):
    def __init__(self, anno_path, data_args=None, fps=2.0, name='synthetic_ocr'):
        super().__init__(anno_path=anno_path,
                         data_args=data_args,
                         fps=fps,
                         name=name)
        # Set after super().__init__() so it overrides whatever default_fps
        # the base class assigns.
        self.default_fps = 0.1

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        item = self.annotation[i]
        ret = {
            'images': self.vis_preprocess(item['video_path']),
            'conversations': self.text_preprocess(item)
        }
        if 'id' in item:
            ret['id'] = item['id']
        return ret

    def text_preprocess(self, item) -> List[Dict[str, str]]:
        # data_args.caption_prompt, when present, is a string expression
        # naming a prompt list; eval resolves it. Otherwise fall back to
        # the default caption prompts.
        if hasattr(self.data_args, 'caption_prompt'):
            cap_prompt = eval(self.data_args.caption_prompt)
        else:
            cap_prompt = tt_caption_prompt

        conversations = [
            {
                'from': 'human',
                'value': DEFAULT_VIDEO_TOKEN + random.choice(cap_prompt)
            },
            {
                'from': 'model',
                'value': item['gpt_caption'] + ' ' + random.choice(ocr_prompt) + ','.join(item['ocr_list'])
            }
        ]
        return conversations
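
# For illustration (not part of the original file): text_preprocess returns a
# two-turn conversation. Assuming DEFAULT_VIDEO_TOKEN renders as '<video>' and
# the prompt pools contain strings like "Describe the video." and
# "The visible texts are: " (hypothetical placeholders; the real strings live
# in llava.datasets.prompts), the output would look like:
#
#   [
#       {'from': 'human', 'value': '<video>Describe the video.'},
#       {'from': 'model', 'value': '<caption> The visible texts are: SALE,50% OFF'}
#   ]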


@DATASETS.register_obj
def synthetic_ocr(data_args):
    # Prefer an explicitly supplied annotation path; otherwise fall back to
    # the configured default for this dataset.
    if 'train_data_path' in data_args.external_args:
        train_data_path = data_args.external_args['train_data_path']
    else:
        train_data_path = data_configs['synthetic_ocr']['train_data_path']
    return SyntheticOCRDataset(train_data_path, data_args, 2.0)
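
# Hypothetical usage sketch (not in the original file): the registered builder
# can pick up a custom annotation path via data_args.external_args, e.g.
#
#   data_args.external_args = {'train_data_path': '/path/to/my_annotations.json'}
#   dataset = synthetic_ocr(data_args)
#
# Without that key, the path from data_configs['synthetic_ocr'] is used.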


if __name__ == '__main__':
    # Smoke test: build the model-turn text for each annotation entry.
    with open('/mnt/bn/algo-masp-nas-2/xiangchen/dataset/masp/synthetic_ocr/train_filtered.json') as f:
        data = json.load(f)
    for sample in data:
        res = sample['gpt_caption'] + ' ' + random.choice(ocr_prompt) + ','.join(sample['ocr_list'])
        # print(res)
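
# Expected annotation schema, inferred from the field accesses above (the
# field names come from the code; the example values are made up):
#
#   {
#       "id": "sample_0001",                      # optional
#       "video_path": "/path/to/video",           # consumed by vis_preprocess
#       "gpt_caption": "A storefront with signs.",
#       "ocr_list": ["SALE", "50% OFF"]
#   }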