import json
import random
from typing import Dict, List

import torch

from llava.constants import DEFAULT_VIDEO_TOKEN
from llava.datasets.base_dataset import FramesTaskDataset
from llava.datasets.builder import DATASETS
from llava.datasets.data_cfgs import data_configs
from llava.datasets.prompts import tt_caption_prompt, ocr_prompt


class SyntheticOCRDataset(FramesTaskDataset):
    """Frames-based dataset pairing synthetic-OCR videos with caption and OCR targets."""

    def __init__(self, anno_path, data_args=None, fps=2.0, name='synthetic_ocr'):
        super().__init__(anno_path=anno_path,
                         data_args=data_args,
                         fps=fps,
                         name=name)
        # Fallback fps; how it is consumed is defined by FramesTaskDataset.
        self.default_fps = 0.1

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        item = self.annotation[i]

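        # Preprocess the frames at `video_path` and build the paired
        # conversation turns for this annotation.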
        ret = {
            'images': self.vis_preprocess(item['video_path']),
            'conversations': self.text_preprocess(item)
        }
        if 'id' in item:
            ret['id'] = item['id']

        return ret

    def text_preprocess(self, item) -> List[Dict[str, str]]:
        # Resolve the caption prompt pool: `data_args.caption_prompt`, if set,
        # names a prompt list to evaluate; otherwise use the default pool.
        if hasattr(self.data_args, 'caption_prompt'):
            cap_prompt = eval(self.data_args.caption_prompt)
        else:
            cap_prompt = tt_caption_prompt

        # One human/model exchange per sample: the human turn carries the video
        # token plus a sampled caption prompt; the model turn is the GPT caption
        # followed by a sampled OCR prompt and the comma-joined OCR strings.
        conversations = [
            {
                'from': 'human',
                'value': DEFAULT_VIDEO_TOKEN + random.choice(cap_prompt)
            },
            {
                'from': 'model',
                'value': item['gpt_caption'] + ' ' + random.choice(ocr_prompt) + ','.join(item['ocr_list'])
            }
        ]
        return conversations


@DATASETS.register_obj
def synthetic_ocr(data_args):
    # Prefer an explicitly passed annotation path; otherwise fall back to the
    # path registered in data_configs.
    if 'train_data_path' in data_args.external_args:
        train_data_path = data_args.external_args['train_data_path']
    else:
        train_data_path = data_configs['synthetic_ocr']['train_data_path']
    return SyntheticOCRDataset(train_data_path, data_args, fps=2.0)

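# Hedged usage sketch (assumption, not part of the original file): exercising
# the registered builder, assuming DATASETS maps names to builder functions
# and `data_args` exposes `external_args` as used above.
#
#   builder = DATASETS.get('synthetic_ocr')   # hypothetical registry lookup
#   dataset = builder(data_args)
#   sample = dataset[0]
#   # -> {'images': <preprocessed frames>, 'conversations': [...], 'id': ...}
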
if __name__ == '__main__':
    # Smoke test: rebuild the model-side target string for every annotation to
    # verify the caption and OCR fields compose without errors.
    with open('/mnt/bn/algo-masp-nas-2/xiangchen/dataset/masp/synthetic_ocr/train_filtered.json') as f:
        data = json.load(f)

    for sample in data:
        res = sample['gpt_caption'] + ' ' + random.choice(ocr_prompt) + ','.join(sample['ocr_list'])
        # print(res)