Demo750's picture
Upload folder using huggingface_hub
569f484 verified
raw
history blame
14.8 kB
import itertools
import json
import os
import re
from collections import namedtuple
import torch
from tqdm import tqdm
class InferenceSampler(torch.utils.data.sampler.Sampler):
    """Split ``range(size)`` into one contiguous slice per distributed rank.

    The rank and world size are read from the default ``torch.distributed``
    process group at construction time. The first ``size % world_size`` ranks
    receive one extra index; indices are neither shuffled nor padded, so the
    union of all ranks' slices is exactly ``range(size)``.
    """

    def __init__(self, size):
        self._size = int(size)
        assert size > 0
        self._rank = torch.distributed.get_rank()
        self._world_size = torch.distributed.get_world_size()
        self._local_indices = self._get_local_indices(
            size, self._world_size, self._rank
        )

    @staticmethod
    def _get_local_indices(total_size, world_size, rank):
        """Return the ``range`` of indices owned by ``rank``."""
        per_rank, leftover = divmod(total_size, world_size)
        # Every rank before `rank` owns `per_rank` indices, plus one more for
        # each of those earlier ranks that falls below `leftover`.
        first = rank * per_rank + min(rank, leftover)
        count = per_rank + (1 if rank < leftover else 0)
        return range(first, min(first + count, total_size))

    def __iter__(self):
        return iter(self._local_indices)

    def __len__(self):
        return len(self._local_indices)
def collate_fn_vqa(batches):
    """Transpose a list of sample dicts into per-field lists.

    ``image_path``, ``question`` and ``gt_answers`` are required in every
    sample (a missing one raises ``KeyError``). ``ocr_tokens``,
    ``question_id`` and ``question_type`` are optional and become ``None``
    for samples that lack them.

    Returns:
        Tuple ``(image_paths, questions, gt_answers, ocr_tokens,
        question_ids, question_type)``, each a list aligned with ``batches``.
    """
    def required(key):
        return [sample[key] for sample in batches]

    def optional(key):
        return [sample[key] if key in sample else None for sample in batches]

    return (
        required('image_path'),
        required('question'),
        required('gt_answers'),
        optional('ocr_tokens'),
        optional('question_id'),
        optional('question_type'),
    )
def has_word(sentence, word):
    """Return True if ``word`` appears in ``sentence`` as a whole token.

    A regex word boundary is required on each side of ``word`` whose edge
    character is alphanumeric, so "cat" does not match inside "concatenate",
    while a punctuation-edged needle such as "$5" is matched without a
    leading boundary. The match is case-sensitive.

    An empty ``word`` carries no evidence and never matches (it used to
    raise ``IndexError`` on ``word[0]``).
    """
    if not word:
        return False
    start_pattern = r"\b" if word[0].isalnum() else r""
    end_pattern = r"\b" if word[-1].isalnum() else r""
    pattern = start_pattern + re.escape(word) + end_pattern
    return re.search(pattern, sentence) is not None
def remove_special_chars(s):
    """Drop every character that is not an ASCII letter, digit, or whitespace.

    >>> remove_special_chars("Hello, World! 42")
    'Hello World 42'
    """
    return re.sub(r"[^a-zA-Z0-9\s]", "", s)
def levenshtein_distance(s1, s2):
    """Return the Levenshtein (insert/delete/substitute) distance of two sequences.

    Uses the two-row dynamic-programming formulation, iterating over the
    longer sequence so each row is sized by the shorter one.
    """
    shorter, longer = (s2, s1) if len(s1) > len(s2) else (s1, s2)
    previous_row = list(range(len(shorter) + 1))
    for row_index, longer_char in enumerate(longer, start=1):
        current_row = [row_index]
        for col, shorter_char in enumerate(shorter):
            if shorter_char == longer_char:
                cell = previous_row[col]
            else:
                cell = 1 + min(previous_row[col], previous_row[col + 1], current_row[-1])
            current_row.append(cell)
        previous_row = current_row
    return previous_row[-1]
class VQAEval:
    """Answer normalisation and per-question scoring for VQA-style benchmarks.

    NOTE(review): the normalisation tables below appear to mirror the
    official VQA evaluation script; changing them shifts scores relative to
    published numbers, so confirm before editing.
    """

    def __init__(self):
        # Maps apostrophe-less spellings to their canonical contraction.
        self.contractions = {
            "aint": "ain't",
            "arent": "aren't",
            "cant": "can't",
            "couldve": "could've",
            "couldnt": "couldn't",
            "couldn'tve": "couldn't've",
            "couldnt've": "couldn't've",
            "didnt": "didn't",
            "doesnt": "doesn't",
            "dont": "don't",
            "hadnt": "hadn't",
            "hadnt've": "hadn't've",
            "hadn'tve": "hadn't've",
            "hasnt": "hasn't",
            "havent": "haven't",
            "hed": "he'd",
            "hed've": "he'd've",
            "he'dve": "he'd've",
            "hes": "he's",
            "howd": "how'd",
            "howll": "how'll",
            "hows": "how's",
            "Id've": "I'd've",
            "I'dve": "I'd've",
            "Im": "I'm",
            "Ive": "I've",
            "isnt": "isn't",
            "itd": "it'd",
            "itd've": "it'd've",
            "it'dve": "it'd've",
            "itll": "it'll",
            "let's": "let's",
            "maam": "ma'am",
            "mightnt": "mightn't",
            "mightnt've": "mightn't've",
            "mightn'tve": "mightn't've",
            "mightve": "might've",
            "mustnt": "mustn't",
            "mustve": "must've",
            "neednt": "needn't",
            "notve": "not've",
            "oclock": "o'clock",
            "oughtnt": "oughtn't",
            "ow's'at": "'ow's'at",
            "'ows'at": "'ow's'at",
            "'ow'sat": "'ow's'at",
            "shant": "shan't",
            "shed've": "she'd've",
            "she'dve": "she'd've",
            "she's": "she's",
            "shouldve": "should've",
            "shouldnt": "shouldn't",
            "shouldnt've": "shouldn't've",
            "shouldn'tve": "shouldn't've",
            # Was inverted ("somebody'd" -> "somebodyd"), which un-contracted
            # the correct spelling instead of repairing the missing apostrophe.
            "somebodyd": "somebody'd",
            "somebodyd've": "somebody'd've",
            "somebody'dve": "somebody'd've",
            "somebodyll": "somebody'll",
            "somebodys": "somebody's",
            "someoned": "someone'd",
            "someoned've": "someone'd've",
            "someone'dve": "someone'd've",
            "someonell": "someone'll",
            "someones": "someone's",
            "somethingd": "something'd",
            "somethingd've": "something'd've",
            "something'dve": "something'd've",
            "somethingll": "something'll",
            "thats": "that's",
            "thered": "there'd",
            "thered've": "there'd've",
            "there'dve": "there'd've",
            "therere": "there're",
            "theres": "there's",
            "theyd": "they'd",
            "theyd've": "they'd've",
            "they'dve": "they'd've",
            "theyll": "they'll",
            "theyre": "they're",
            "theyve": "they've",
            "twas": "'twas",
            "wasnt": "wasn't",
            "wed've": "we'd've",
            "we'dve": "we'd've",
            "weve": "we've",
            "werent": "weren't",
            "whatll": "what'll",
            "whatre": "what're",
            "whats": "what's",
            "whatve": "what've",
            "whens": "when's",
            "whered": "where'd",
            "wheres": "where's",
            "whereve": "where've",
            "whod": "who'd",
            "whod've": "who'd've",
            "who'dve": "who'd've",
            "wholl": "who'll",
            "whos": "who's",
            "whove": "who've",
            "whyll": "why'll",
            "whyre": "why're",
            "whys": "why's",
            "wont": "won't",
            "wouldve": "would've",
            "wouldnt": "wouldn't",
            "wouldnt've": "wouldn't've",
            "wouldn'tve": "wouldn't've",
            "yall": "y'all",
            "yall'll": "y'all'll",
            "y'allll": "y'all'll",
            "yall'd've": "y'all'd've",
            "y'alld've": "y'all'd've",
            "y'all'dve": "y'all'd've",
            "youd": "you'd",
            "youd've": "you'd've",
            "you'dve": "you'd've",
            "youll": "you'll",
            "youre": "you're",
            "youve": "you've",
        }
        # Number words normalised to digits.
        self.manualMap = {
            "none": "0",
            "zero": "0",
            "one": "1",
            "two": "2",
            "three": "3",
            "four": "4",
            "five": "5",
            "six": "6",
            "seven": "7",
            "eight": "8",
            "nine": "9",
            "ten": "10",
        }
        self.articles = ["a", "an", "the"]
        # NOTE(review): "(?!<=" looks like a typo for the lookbehind "(?<!",
        # so this strips every period not followed by a digit (including
        # ones after a digit, e.g. "1."). Kept as-is to preserve scores.
        self.periodStrip = re.compile(r"(?!<=\d)(\.)(?!\d)")
        # Digit-comma-digit, e.g. the thousands separator in "1,000".
        self.commaStrip = re.compile(r"(\d)(\,)(\d)")
        self.punct = [
            ";",
            r"/",
            "[",
            "]",
            '"',
            "{",
            "}",
            "(",
            ")",
            "=",
            "+",
            "\\",
            "_",
            "-",
            ">",
            "<",
            "@",
            "`",
            ",",
            "?",
            "!",
        ]

    def clean_text(self, text):
        """Flatten tabs/newlines, then apply punctuation and digit/article normalisation."""
        text = text.replace("\n", " ").replace("\t", " ").strip()
        text = self.processPunctuation(text)
        text = self.processDigitArticle(text)
        return text

    def evaluate_vqa_human(self, answer, gt_answers):
        """Human-consensus VQA accuracy (TextVQA, VQAv2, OKVQA, VizWiz).

        For each ground truth, count how many of the *other* ground truths
        equal the normalised prediction; ``min(1, count / 3)`` is that
        annotator's accuracy and the mean over annotators is returned.

        Args:
            answer: Model prediction.
            gt_answers: List of annotator answers (a single string is also
                accepted and treated as a one-element list).

        Returns:
            Mean accuracy in [0, 1]; 0 when there are no ground truths.
        """
        answer = self.clean_text(answer)
        if isinstance(gt_answers, str):
            gt_answers = [gt_answers]
        # Ground truths skip clean_text's newline/tab/strip step, exactly as
        # before, so existing scores are unchanged.
        gt_answers = [
            self.processDigitArticle(self.processPunctuation(ans)) for ans in gt_answers
        ]
        total_matches = gt_answers.count(answer)
        # Leave-one-out: exclude the annotator being scored from its own count.
        gt_acc = [
            min(1, float(total_matches - (gt == answer)) / 3) for gt in gt_answers
        ]
        return float(sum(gt_acc)) / len(gt_acc) if gt_acc else 0

    def evaluate_anls(self, answer, gt_answers, threshold=0.5):
        """ANLS score for one question (DocVQA, InfographicsVQA, ST-VQA).

        Prediction and references are lower-cased and whitespace-collapsed.
        The score is ``1 - min(normalised edit distance)`` over the
        references, zeroed when it falls below ``threshold``.

        Args:
            answer: Model prediction.
            gt_answers: Reference answer or list of reference answers.
            threshold: Minimum similarity that counts as a (partial) hit.

        Returns:
            Score in [0, 1]; 0.0 when there are no references.
        """
        answer = ' '.join(answer.strip().lower().split())
        if not isinstance(gt_answers, list):
            gt_answers = [gt_answers]
        if not gt_answers:
            return 0.0
        gt_answers = [' '.join(gt_answer.strip().lower().split()) for gt_answer in gt_answers]
        values = []
        for gt_answer in gt_answers:
            dist = levenshtein_distance(answer, gt_answer)
            length = max(len(answer), len(gt_answer))
            values.append(0.0 if length == 0 else float(dist) / float(length))
        score = 1 - min(values)
        return 0 if score < threshold else score

    def evaluate_has(self, answer, gt_answers):
        """Containment accuracy: does the prediction mention any ground truth?

        Both sides are normalised with ``clean_text``; a ground truth counts
        when it occurs in the prediction as a whole word/phrase (see
        ``has_word``). Ground truths that normalise to "" are ignored.

        Args:
            answer: Model prediction.
            gt_answers: Reference answer or list of reference answers.

        Returns:
            1.0 if any ground truth is found, else 0.0.
        """
        if isinstance(gt_answers, str):
            gt_answers = [gt_answers]
        answer = self.clean_text(answer)
        for gt_answer in gt_answers:
            gt_answer = self.clean_text(str(gt_answer))
            if gt_answer and has_word(answer, gt_answer):
                return 1.0
        return 0.0

    def processPunctuation(self, inText):
        """Remove or space-out punctuation, then strip sentence periods.

        Each mark in ``self.punct`` is deleted when it is adjacent to a space
        in the input, or when the input contains a digit-comma-digit
        sequence; otherwise it is replaced by a space. Periods not followed
        by a digit are then removed.
        """
        outText = inText
        # Loop-invariant: the test is on the original input, not outText.
        has_digit_comma = self.commaStrip.search(inText) is not None
        for p in self.punct:
            if p + " " in inText or " " + p in inText or has_digit_comma:
                outText = outText.replace(p, "")
            else:
                outText = outText.replace(p, " ")
        # Previously re.UNICODE was passed positionally here; Pattern.sub
        # reads that slot as `count`, so only 32 periods were ever removed.
        outText = self.periodStrip.sub("", outText)
        return outText

    def processDigitArticle(self, inText):
        """Lower-case, map number words to digits, drop articles, fix contractions."""
        words = []
        for word in inText.lower().split():
            # .get (not .setdefault) so manualMap is not grown with every word seen.
            word = self.manualMap.get(word, word)
            if word not in self.articles:
                words.append(self.contractions.get(word, word))
        return " ".join(words)
def evaluate_dataset(dataset_name, answer_file_path, model_name, method = None):
    """Score a saved predictions file and write a per-item scored copy.

    The metric is chosen by ``dataset_name``: ``"textVQA"`` uses
    human-consensus VQA accuracy, ``"docVQA"`` uses ANLS, and anything else
    uses ``VQAEval.evaluate_has`` (ground truth contained in the answer).

    Args:
        dataset_name: Benchmark name used to pick the metric and to label
            the printed result.
        answer_file_path: JSON list of dicts, each with at least ``answer``
            and ``gt_answers`` keys.
        model_name: Used in the scored output filename.
        method: Used in the scored output filename (``None`` is rendered as
            ``"None"``).

    Returns:
        Mean accuracy over all items, or 0.0 if the file holds no items.
        The scored copy is written next to the input as
        ``<stem>_<model_name>_<method><ext>``, with an ``accuracy`` key
        added to every item.
    """
    with open(answer_file_path, 'r', encoding='utf-8') as f:
        predictions = json.load(f)

    evaluator = VQAEval()
    total_accuracy = 0
    for index, item in enumerate(predictions):
        gt_answers = item['gt_answers']
        answer = item['answer']
        if dataset_name in ["textVQA"]:
            if index == 0:
                print("evaluating vqa...")
            accuracy = evaluator.evaluate_vqa_human(answer, gt_answers)
        elif dataset_name in ['docVQA']:
            if index == 0:
                print("evaluating anls...")
            accuracy = evaluator.evaluate_anls(answer, gt_answers)
        else:
            accuracy = evaluator.evaluate_has(answer, gt_answers)
        item['accuracy'] = accuracy
        total_accuracy += accuracy

    num = len(predictions)
    # Avoid ZeroDivisionError on an empty predictions file.
    average_accuracy = total_accuracy / num if num else 0.0
    print(f'{dataset_name}:{average_accuracy}')

    # splitext only touches the final extension; str.replace('.json', ...)
    # also rewrote directory names and, for non-.json inputs, returned the
    # input path unchanged so the input file was overwritten.
    stem, ext = os.path.splitext(answer_file_path)
    answer_model_method_path = f'{stem}_{model_name}_{method}{ext}'
    with open(answer_model_method_path, "w", encoding='utf-8') as f:
        json.dump(predictions, f, indent=4, ensure_ascii=False)
    return average_accuracy
def evaluate_VQA(
    model,
    dataset,
    model_name,
    dataset_name,
    time,
    batch_size=1,
    generate_method="interleave",
    answer_path='./answers',
):
    """Run VQA inference over ``dataset``, save the answers, and score them.

    When ``torch.distributed`` is initialised, each rank runs inference on
    its own shard (``InferenceSampler``). The per-rank predictions are
    gathered with ``all_gather_object``, and only rank 0 writes and scores
    them. Otherwise a single process handles the whole dataset.

    Args:
        model: Wrapper exposing ``generate`` (and, for ``"minicpm"``,
            ``generate_with_interleaved``).
        dataset: Indexable dataset whose items provide the keys read by
            ``collate_fn_vqa``.
        model_name: Selects the generation call; also a directory level
            under ``answer_path``.
        dataset_name: Forwarded to the model, names the answer file, and
            selects the metric.
        time: Run identifier used as a sub-directory name.
        batch_size: DataLoader batch size.
        generate_method: ``"old"`` or ``"interleave"``; only consulted when
            ``model_name == "minicpm"``.
        answer_path: Root directory for answer files.

    Returns:
        The average accuracy from ``evaluate_dataset``, ``-1.0`` for
        ``"docVQATest"`` (answers are saved but not scored), or ``None`` on
        ranks other than 0.

    Raises:
        ValueError: ``model_name == "minicpm"`` with an unknown
            ``generate_method`` (raised on the first batch).
    """
    print(f"answer path:{answer_path}")
    distributed = torch.distributed.is_initialized()
    sampler = InferenceSampler(len(dataset)) if distributed else None
    dataloader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        sampler=sampler,
        collate_fn=collate_fn_vqa,
    )
    answer_dir = os.path.join(answer_path, model_name, time)
    os.makedirs(answer_dir, exist_ok=True)

    predictions = []
    for batch in tqdm(dataloader, desc="Running inference"):
        image_paths, questions, gt_answers, _ocr_tokens, question_ids, question_type = batch
        with torch.no_grad():
            outputs = _generate_vqa_outputs(
                model, model_name, generate_method, image_paths, questions, dataset_name
            )
        for i, output in enumerate(outputs):
            predictions.append({
                'question_id': question_ids[i],
                'question': questions[i],
                'answer': output,
                'gt_answers': gt_answers[i],
                'image_path': image_paths[i],
                'model_name': model_name,
                'question_type': question_type[i],
            })

    if distributed:
        # Sync once after every rank has finished its own shard, then
        # collect all shards on every rank.
        torch.distributed.barrier()
        gathered = [None] * torch.distributed.get_world_size()
        torch.distributed.all_gather_object(gathered, predictions)
        predictions = list(itertools.chain.from_iterable(gathered))
        if torch.distributed.get_rank() != 0:
            return None

    answer_file_path = os.path.join(answer_dir, f"{dataset_name}.json")
    print(f"answer_file_path:{answer_file_path}")
    with open(answer_file_path, "w", encoding='utf-8') as f:
        json.dump(predictions, f, indent=4, ensure_ascii=False)
    if dataset_name in ["docVQATest"]:
        return -1.0
    return evaluate_dataset(answer_file_path=answer_file_path, dataset_name=dataset_name, model_name=model_name)


def _generate_vqa_outputs(model, model_name, generate_method, image_paths, questions, dataset_name):
    """Run one batch through the generation call that ``model_name`` uses."""
    if model_name == "minicpm":
        if generate_method == "old":
            return model.generate(images=image_paths, questions=questions, datasetname=dataset_name)
        if generate_method == "interleave":
            return model.generate_with_interleaved(images=image_paths, questions=questions, datasetname=dataset_name)
        raise ValueError(f"Wrong generate paradigm {generate_method}!")
    if model_name == "codellama":
        # NOTE(review): called without the batch inputs, as before; confirm
        # this wrapper really takes no arguments.
        return model.generate()
    return model.generate(images=image_paths, questions=questions, datasetname=dataset_name)