# File size: 3,668 Bytes
# 569f484
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import json
import os
import re
from torch.utils.data import Dataset

def prompt_processor(prompt):
    """Extract the lower-cased question from a formatted VQA prompt.

    Supported layouts, checked in this order:

    1. The prompt starts with ``'OCR tokens: '``: the question is the text
       between ``'Question: '`` and ``' Short answer:'``.
    2. The prompt contains ``'Reference OCR token: '`` and has exactly three
       newline-separated lines: the question is the second line when the
       prompt starts with ``'Reference OCR token:'``, otherwise the first.
    3. The prompt has exactly two newline-separated lines: the question is
       the first line.

    Args:
        prompt: The full prompt string.

    Returns:
        The extracted question, lower-cased.

    Raises:
        ValueError: If an ``'OCR tokens: '`` prompt has no
            ``Question: ... Short answer:`` section.
        AssertionError: If the prompt matches none of the supported layouts.
    """
    if prompt.startswith('OCR tokens: '):
        pattern = r"Question: (.*?) Short answer:"
        match = re.search(pattern, prompt, re.DOTALL)
        if match is None:
            # Fail with a clear message instead of "'NoneType' object has no
            # attribute 'group'".
            raise ValueError(
                "prompt starts with 'OCR tokens: ' but has no "
                f"'Question: ... Short answer:' section: {prompt!r}"
            )
        return match.group(1).lower()

    lines = prompt.split('\n')
    if 'Reference OCR token: ' in prompt and len(lines) == 3:
        # The OCR reference may come before or after the question line.
        if prompt.startswith('Reference OCR token:'):
            return lines[1].lower()
        return lines[0].lower()
    if len(lines) == 2:
        return lines[0].lower()

    # Raised explicitly (rather than ``assert False``) so the check survives
    # ``python -O`` while keeping the exception type callers saw before.
    raise AssertionError(f"unsupported prompt format: {prompt!r}")
    
class textVQADataset(Dataset):
    """Dataset over a TextVQA annotation file.

    Each item is a dict holding the question id, the path of the image file,
    the question text and the ground-truth answers taken from the annotation
    record.
    """

    def __init__(
        self,
        image_dir="./downloads/TextVQA/train_images",
        ann_path="./downloads/TextVQA/TextVQA_0.5.1_val.json",
    ):
        """Load the annotation records.

        Args:
            image_dir: Directory that contains the ``<image_id>.jpg`` files.
            ann_path: Path to a JSON file whose top-level ``"data"`` entry is
                the sequence of annotation records.
        """
        # Use a context manager so the file handle is closed deterministically,
        # and read as UTF-8 regardless of the platform's default encoding.
        with open(ann_path, "r", encoding="utf-8") as ann_file:
            self.data = json.load(ann_file)["data"]
        self.image_dir = image_dir

    def __len__(self):
        """Return the number of annotation records."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the sample at position ``idx``.

        Returns:
            A dict with keys ``question_id``, ``image_path``, ``question`` and
            ``gt_answers``.
        """
        record = self.data[idx]
        return {
            "question_id": record['question_id'],
            "image_path": os.path.join(self.image_dir, f"{record['image_id']}.jpg"),
            "question": record['question'],
            "gt_answers": record['answers'],
        }
    
class docVQADataset(Dataset):
    """Dataset over a DocVQA annotation file that includes question types.

    Each item is a dict holding the question id, the path of the image file,
    the question text, the ground-truth answers and the question type(s)
    taken from the annotation record.
    """

    def __init__(
        self,
        image_dir= "./downloads/DocVQA/spdocvqa_images",
        ann_path= "./downloads/DocVQA/val_v1.0_withQT.json",
        ocr_token_path=None
    ):
        """Load the annotation records and, optionally, OCR token records.

        Args:
            image_dir: Root directory the per-record image paths are joined to.
            ann_path: Path to a JSON file whose top-level ``"data"`` entry is
                the sequence of annotation records.
            ocr_token_path: Optional path to a JSON file whose ``"data"``
                entries each carry an ``image_id``. When given, the entries
                are indexed by ``image_id`` in ``self.ocr_token_data``.
        """
        # Use context managers so the file handles are closed
        # deterministically, and read as UTF-8 regardless of the platform's
        # default encoding.
        with open(ann_path, "r", encoding="utf-8") as ann_file:
            self.data = json.load(ann_file)["data"]
        self.image_dir = image_dir
        self.ann_path = ann_path
        # NOTE: ``ocr_token_data`` is only set when an OCR token file is
        # supplied; the attribute does not exist otherwise.
        if ocr_token_path:
            with open(ocr_token_path, "r", encoding="utf-8") as ocr_file:
                ocr_records = json.load(ocr_file)["data"]
            self.ocr_token_data = {item['image_id']: item for item in ocr_records}

    def __len__(self):
        """Return the number of annotation records."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the sample at position ``idx``.

        Returns:
            A dict with keys ``question_id``, ``image_path``, ``question``,
            ``gt_answers`` and ``question_type``.
        """
        record = self.data[idx]
        # The annotation stores paths containing "documents", which is
        # rewritten to "images" before joining with ``image_dir``.
        relative_img_path = record['image'].replace("documents", "images")
        return {
            "question_id": record['questionId'],
            "image_path": os.path.join(self.image_dir, relative_img_path),
            "question": record['question'],
            "gt_answers": record['answers'],
            'question_type': record['question_types'],
        }


class docVQATESTDataset(Dataset):
    """Dataset over a DocVQA test annotation file.

    Items have the same keys as those of the validation dataset, but
    ``gt_answers`` and ``question_type`` are always empty strings because
    they are not read from the annotation records.
    """

    def __init__(
        self,
        image_dir= "./downloads/DocVQA/spdocvqa_images",
        ann_path= "./downloads/DocVQA/test_v1.0.json",
        ocr_token_path=None
    ):
        """Load the annotation records.

        Args:
            image_dir: Root directory the per-record image paths are joined to.
            ann_path: Path to a JSON file whose top-level ``"data"`` entry is
                the sequence of annotation records.
            ocr_token_path: Accepted but not used by this class.
        """
        # Use a context manager so the file handle is closed deterministically,
        # and read as UTF-8 regardless of the platform's default encoding.
        with open(ann_path, "r", encoding="utf-8") as ann_file:
            self.data = json.load(ann_file)["data"]
        self.image_dir = image_dir
        self.ann_path = ann_path

    def __len__(self):
        """Return the number of annotation records."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the sample at position ``idx``.

        Returns:
            A dict with keys ``question_id``, ``image_path``, ``question``,
            ``gt_answers`` (always ``""``) and ``question_type`` (always
            ``""``).
        """
        record = self.data[idx]
        # The annotation stores paths containing "documents", which is
        # rewritten to "images" before joining with ``image_dir``.
        relative_img_path = record['image'].replace("documents", "images")
        return {
            "question_id": record['questionId'],
            "image_path": os.path.join(self.image_dir, relative_img_path),
            "question": record['question'],
            "gt_answers": "",
            'question_type': "",
        }