import torch
import random
import json
from typing import Dict, List

from llava.datasets.builder import DATASETS
from llava.datasets.data_cfgs import data_configs
from llava.datasets.base_dataset import FramesTaskDataset
from llava.datasets.prompts import tt_caption_prompt, ocr_prompt
from llava.constants import DEFAULT_VIDEO_TOKEN


class SyntheticOCRDataset(FramesTaskDataset):
    """Video OCR dataset: pairs sampled video frames with a caption plus the OCR text found in the clip."""

    def __init__(self, anno_path, data_args=None, fps=2.0, name='synthetic_ocr'):
        super().__init__(anno_path=anno_path,
                         data_args=data_args,
                         fps=fps,
                         name=name)
        # Fallback sampling rate used when a clip carries no fps metadata.
        self.default_fps = 0.1

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        item = self.annotation[i]

        ret = {
            'images': self.vis_preprocess(item['video_path']),
            'conversations': self.text_preprocess(item)
        }
        if 'id' in item:
            ret['id'] = item['id']
        return ret

    def text_preprocess(self, item) -> List[Dict[str, str]]:
        # Allow the caption prompt pool to be overridden via data_args; the
        # attribute stores the *name* of a prompt list, hence the eval.
        if hasattr(self.data_args, 'caption_prompt'):
            cap_prompt = eval(self.data_args.caption_prompt)
        else:
            cap_prompt = tt_caption_prompt

        # Single-turn conversation: the human asks for a caption of the video,
        # and the model answers with the caption followed by an OCR prompt and
        # the comma-joined OCR strings.
        conversations = [
            {
                'from': 'human',
                'value': DEFAULT_VIDEO_TOKEN + random.choice(cap_prompt)
            },
            {
                'from': 'model',
                'value': item['gpt_caption'] + ' ' + random.choice(ocr_prompt) + ','.join(item['ocr_list'])
            }
        ]
        return conversations


@DATASETS.register_obj
def synthetic_ocr(data_args):
    # Prefer an explicitly supplied annotation path; otherwise fall back to
    # the path registered in data_configs.
    if 'train_data_path' in data_args.external_args:
        train_data_path = data_args.external_args['train_data_path']
    else:
        train_data_path = data_configs['synthetic_ocr']['train_data_path']
    return SyntheticOCRDataset(train_data_path, data_args, 2.0)


if __name__ == '__main__':
    # Sanity check: build the model-side answer string for every annotation.
    with open('/mnt/bn/algo-masp-nas-2/xiangchen/dataset/masp/synthetic_ocr/train_filtered.json') as f:
        data = json.load(f)
    for sample in data:
        res = sample['gpt_caption'] + ' ' + random.choice(ocr_prompt) + ','.join(sample['ocr_list'])
        # print(res)
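
# --- Usage sketch (not part of the module) ---
# A minimal illustration of how the registered builder might be invoked.
# The SimpleNamespace stand-in for data_args is an assumption: the real
# data_args object in this codebase likely carries more fields consumed by
# FramesTaskDataset, so treat this as a sketch rather than a working driver.
#
#   from types import SimpleNamespace
#   data_args = SimpleNamespace(external_args={'train_data_path': 'annotations.json'})
#   dataset = synthetic_ocr(data_args)
#   sample = dataset[0]  # {'images': ..., 'conversations': [...], 'id': ...}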