import json
import os
import re

from torch.utils.data import Dataset


def prompt_processor(prompt):
    """Extract the question text from one of a few known evaluation prompt formats."""
    if prompt.startswith('OCR tokens: '):
        # Prompt of the form "OCR tokens: ... Question: ... Short answer:"
        pattern = r"Question: (.*?) Short answer:"
        match = re.search(pattern, prompt, re.DOTALL)
        question = match.group(1)
    elif 'Reference OCR token: ' in prompt and len(prompt.split('\n')) == 3:
        # Three-line prompt: the question sits on the line after (or before)
        # the "Reference OCR token:" line, depending on its position.
        if prompt.startswith('Reference OCR token:'):
            question = prompt.split('\n')[1]
        else:
            question = prompt.split('\n')[0]
    elif len(prompt.split('\n')) == 2:
        # Two-line prompt: the question is the first line.
        question = prompt.split('\n')[0]
    else:
        raise ValueError(f"Unrecognized prompt format: {prompt!r}")
    return question.lower()
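
# Illustrative example (assumed prompt, not taken from the original source):
#   prompt_processor("OCR tokens: stop Question: What does the sign say? Short answer:")
# returns "what does the sign say?" -- the text between "Question:" and
# "Short answer:", lowercased.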


class textVQADataset(Dataset):
    """TextVQA validation set: pairs each question with its image path and reference answers."""

    def __init__(
        self,
        image_dir="./downloads/TextVQA/train_images",
        ann_path="./downloads/TextVQA/TextVQA_0.5.1_val.json",
    ):
        with open(ann_path, "r") as f:
            self.data = json.load(f)["data"]
        self.image_dir = image_dir

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        question = self.data[idx]['question']
        answers = self.data[idx]['answers']
        img_id = self.data[idx]['image_id']
        qid = self.data[idx]['question_id']
        img_path = os.path.join(self.image_dir, f"{img_id}.jpg")
        item = {
            "question_id": qid,
            "image_path": img_path,
            "question": question,
            "gt_answers": answers,
        }
        return item
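
# Usage note: items are plain dicts of strings and lists, so wrapping these
# datasets in a torch DataLoader typically calls for a custom collate_fn (or
# batch_size=None) rather than the default tensor collation.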


class docVQADataset(Dataset):
    """DocVQA validation set with ground-truth answers and question types."""

    def __init__(
        self,
        image_dir="./downloads/DocVQA/spdocvqa_images",
        ann_path="./downloads/DocVQA/val_v1.0_withQT.json",
        ocr_token_path=None,
    ):
        with open(ann_path, "r") as f:
            self.data = json.load(f)["data"]
        self.image_dir = image_dir
        self.ann_path = ann_path
        self.ocr_token_data = None
        if ocr_token_path:
            with open(ocr_token_path, "r") as f:
                self.ocr_token_data = {item['image_id']: item for item in json.load(f)["data"]}

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        question_id = self.data[idx]['questionId']
        relative_img_path = self.data[idx]['image']
        # Map the annotation's "documents" prefix to the local "images" directory.
        corrected_relative_img_path = relative_img_path.replace("documents", "images")
        img_path = os.path.join(self.image_dir, corrected_relative_img_path)
        question = self.data[idx]['question']
        answers = self.data[idx]['answers']
        question_type = self.data[idx]['question_types']
        return {
            "question_id": question_id,
            "image_path": img_path,
            "question": question,
            "gt_answers": answers,
            "question_type": question_type,
        }


class docVQATESTDataset(Dataset):
    """DocVQA test set: no ground-truth answers or question types are available."""

    def __init__(
        self,
        image_dir="./downloads/DocVQA/spdocvqa_images",
        ann_path="./downloads/DocVQA/test_v1.0.json",
        ocr_token_path=None,
    ):
        with open(ann_path, "r") as f:
            self.data = json.load(f)["data"]
        self.image_dir = image_dir
        self.ann_path = ann_path

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        question_id = self.data[idx]['questionId']
        relative_img_path = self.data[idx]['image']
        # Same "documents" -> "images" path correction as in docVQADataset.
        corrected_relative_img_path = relative_img_path.replace("documents", "images")
        img_path = os.path.join(self.image_dir, corrected_relative_img_path)
        question = self.data[idx]['question']
        return {
            "question_id": question_id,
            "image_path": img_path,
            "question": question,
            "gt_answers": "",
            "question_type": "",
        }
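

# Minimal smoke test (a sketch, not part of the original module): assumes the
# default TextVQA download paths above exist locally; adjust image_dir/ann_path
# otherwise.
if __name__ == "__main__":
    print(prompt_processor(
        "OCR tokens: stop Question: What does the sign say? Short answer:"
    ))
    dataset = textVQADataset()
    print(f"{len(dataset)} TextVQA samples")
    sample = dataset[0]
    print(sample["question_id"], sample["image_path"], sample["question"])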