Demo750's picture
Upload folder using huggingface_hub
569f484 verified
raw
history blame
3.67 kB
import json
import os
import re
from torch.utils.data import Dataset
def prompt_processor(prompt):
    """Extract the bare question from a VQA prompt and lowercase it.

    Three prompt layouts are recognized:
      1. Starts with 'OCR tokens: ' -> question is between
         'Question: ' and ' Short answer:'.
      2. Contains 'Reference OCR token: ' and has exactly 3 lines ->
         question is the line before/after the OCR-token line, depending
         on whether the prompt starts with it.
      3. Exactly 2 lines -> question is the first line.

    Args:
        prompt: Full prompt string fed to the model.

    Returns:
        The extracted question, lowercased.

    Raises:
        ValueError: If the prompt matches none of the known layouts, or
            the expected 'Question: ... Short answer:' span is missing.
    """
    lines = prompt.split('\n')
    if prompt.startswith('OCR tokens: '):
        # Non-greedy match so only the question text is captured;
        # DOTALL lets the question itself span multiple lines.
        match = re.search(r"Question: (.*?) Short answer:", prompt, re.DOTALL)
        if match is None:
            raise ValueError(
                "OCR-token prompt missing 'Question: ... Short answer:' span"
            )
        question = match.group(1)
    elif 'Reference OCR token: ' in prompt and len(lines) == 3:
        if prompt.startswith('Reference OCR token:'):
            question = lines[1]
        else:
            question = lines[0]
    elif len(lines) == 2:
        question = lines[0]
    else:
        # Raise instead of `assert False`: asserts vanish under `python -O`
        # and give no diagnostic message.
        raise ValueError(f"Unrecognized prompt format: {prompt!r}")
    return question.lower()
class textVQADataset(Dataset):
    """TextVQA validation split as a torch Dataset.

    Each item pairs a question with its image path (built from the
    annotation's `image_id` as `<image_dir>/<image_id>.jpg`) and the
    ground-truth answer list.
    """

    def __init__(
        self,
        image_dir="./downloads/TextVQA/train_images",
        ann_path="./downloads/TextVQA/TextVQA_0.5.1_val.json",
    ):
        """Load annotations from `ann_path`; `image_dir` holds the JPEGs."""
        # Context manager so the annotation file handle is closed promptly
        # instead of leaking until garbage collection.
        with open(ann_path, "r") as f:
            self.data = json.load(f)["data"]
        self.image_dir = image_dir

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return a dict with question_id, image_path, question, gt_answers."""
        record = self.data[idx]
        img_path = os.path.join(self.image_dir, f"{record['image_id']}.jpg")
        return {
            "question_id": record["question_id"],
            "image_path": img_path,
            "question": record["question"],
            "gt_answers": record["answers"],
        }
class docVQADataset(Dataset):
    """DocVQA validation split (with question types) as a torch Dataset.

    Image paths from the annotation file reference a `documents/` folder;
    the on-disk layout uses `images/`, so the prefix is rewritten before
    joining with `image_dir`.
    """

    def __init__(
        self,
        image_dir="./downloads/DocVQA/spdocvqa_images",
        ann_path="./downloads/DocVQA/val_v1.0_withQT.json",
        ocr_token_path=None,
    ):
        """Load annotations; optionally index OCR tokens by image_id."""
        # Context managers so file handles are closed promptly (the
        # original `json.load(open(...))` leaked them).
        with open(ann_path, "r") as f:
            self.data = json.load(f)["data"]
        self.image_dir = image_dir
        self.ann_path = ann_path
        if ocr_token_path:
            with open(ocr_token_path, "r") as f:
                self.ocr_token_data = {
                    item["image_id"]: item for item in json.load(f)["data"]
                }

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return question_id, image_path, question, gt_answers, question_type."""
        record = self.data[idx]
        # Annotations point at "documents/..."; images live under "images/".
        corrected_relative_img_path = record["image"].replace("documents", "images")
        img_path = os.path.join(self.image_dir, corrected_relative_img_path)
        return {
            "question_id": record["questionId"],
            "image_path": img_path,
            "question": record["question"],
            "gt_answers": record["answers"],
            "question_type": record["question_types"],
        }
class docVQATESTDataset(Dataset):
    """DocVQA test split as a torch Dataset (no ground-truth answers).

    Mirrors `docVQADataset` but the test annotations carry no answers or
    question types, so those fields are returned as empty strings to keep
    the item schema uniform across splits.
    """

    def __init__(
        self,
        image_dir="./downloads/DocVQA/spdocvqa_images",
        ann_path="./downloads/DocVQA/test_v1.0.json",
        ocr_token_path=None,
    ):
        """Load annotations from `ann_path`.

        Note: `ocr_token_path` is accepted for signature parity with
        `docVQADataset` but is currently unused.
        """
        # Context manager closes the annotation file handle promptly
        # (the original `json.load(open(...))` leaked it).
        with open(ann_path, "r") as f:
            self.data = json.load(f)["data"]
        self.image_dir = image_dir
        self.ann_path = ann_path

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return question_id, image_path, question; answer fields are empty."""
        record = self.data[idx]
        # Annotations point at "documents/..."; images live under "images/".
        corrected_relative_img_path = record["image"].replace("documents", "images")
        img_path = os.path.join(self.image_dir, corrected_relative_img_path)
        return {
            "question_id": record["questionId"],
            "image_path": img_path,
            "question": record["question"],
            "gt_answers": "",
            "question_type": "",
        }