|
import os |
|
import torch |
|
import random |
|
import json |
|
from pathlib import Path |
|
from llava.datasets.builder import DATASETS |
|
|
|
from typing import Dict, Optional, Sequence, List |
|
from llava.datasets.data_cfgs import data_configs |
|
from llava.datasets.base_dataset import FramesTaskDataset |
|
from llava.datasets.prompts import tt_caption_prompt, ocr_prompt |
|
from llava.constants import DEFAULT_VIDEO_TOKEN |
|
|
|
|
|
class SyntheticOCRDataset(FramesTaskDataset):
    """Frames-based dataset whose model targets combine a GPT caption with
    the on-screen (OCR) text of the clip.

    Each annotation item is expected to provide ``video_path``,
    ``gpt_caption`` and ``ocr_list``; an optional ``id`` is passed through
    unchanged.
    """

    def __init__(self, anno_path, data_args=None, fps=2.0, name='synthetic_ocr'):
        super().__init__(anno_path=anno_path,
                         data_args=data_args,
                         fps=fps,
                         name=name)
        # Fallback sampling rate for items without their own fps.
        # NOTE(review): assigned *after* super().__init__(), so it overrides
        # anything the base class set — confirm this ordering is intentional.
        self.default_fps = 0.1

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        """Return preprocessed frames and the conversation for item *i*."""
        item = self.annotation[i]

        ret = {
            'images': self.vis_preprocess(item['video_path']),
            'conversations': self.text_preprocess(item)
        }
        if 'id' in item:
            ret['id'] = item['id']

        return ret

    def text_preprocess(self, item) -> List[Dict[str, str]]:
        """Build one human/model conversation pair for *item*.

        The human turn asks for a caption (prompt chosen at random); the
        model turn is the GPT caption followed by a randomly chosen OCR
        prompt and the comma-joined OCR strings.
        """
        if hasattr(self.data_args, 'caption_prompt'):
            # SECURITY: eval() resolves a config-supplied string to a prompt
            # list. Acceptable for trusted configs only — never feed
            # untrusted input through this path.
            cap_prompt = eval(self.data_args.caption_prompt)
        else:
            cap_prompt = tt_caption_prompt

        return [
            {
                'from': 'human',
                'value': DEFAULT_VIDEO_TOKEN + random.choice(cap_prompt)
            },
            {
                'from': 'model',
                'value': item['gpt_caption'] + ' ' + random.choice(ocr_prompt) + ','.join(item['ocr_list'])
            }
        ]
|
|
|
|
|
@DATASETS.register_obj
def synthetic_ocr(data_args):
    """Registered dataset factory for the ``synthetic_ocr`` dataset.

    The annotation path is taken from ``data_args.external_args`` when the
    caller supplies one, otherwise from the global ``data_configs`` entry.
    """
    if 'train_data_path' in data_args.external_args:
        train_data_path = data_args.external_args['train_data_path']
    else:
        train_data_path = data_configs["synthetic_ocr"]['train_data_path']
    return SyntheticOCRDataset(train_data_path, data_args, 2.0)
|
|
|
if __name__ == '__main__':
    # Smoke-check: rebuild the model-target string for every sample in a
    # local annotation dump, mirroring SyntheticOCRDataset.text_preprocess.
    anno_file = Path('/mnt/bn/algo-masp-nas-2/xiangchen/dataset/masp/synthetic_ocr/train_filtered.json')
    with anno_file.open() as fp:
        samples = json.load(fp)

    for entry in samples:
        target = entry['gpt_caption'] + ' ' + random.choice(ocr_prompt) + ','.join(entry['ocr_list'])
|
|