WhiteWolf21's picture
Initialization
be13417
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import logging
import json
import os
import torch
from tqdm import tqdm
from lavis.common.utils import is_convertible_to_int
import lavis.common.dist_utils as dist_utils
from lavis.common.registry import registry
from lavis.common.vqa_tools.vqa import VQA
from lavis.common.vqa_tools.vqa_eval import VQAEval
from lavis.tasks.base_task import BaseTask
@registry.register_task("vqa")
class VQATask(BaseTask):
def __init__(
self,
num_beams,
max_len,
min_len,
evaluate,
num_ans_candidates,
inference_method="rank",
prompt="",
sample_id_key = "",
ques_files=dict(),
anno_files=dict(),
valid_splits=['val']
):
super().__init__()
self.num_beams = num_beams
self.max_len = max_len
self.min_len = min_len
self.evaluate = evaluate
self.inference_method = inference_method
self.num_ans_candidates = num_ans_candidates
self.prompt = prompt
self.answer_list = None
self.ques_files = ques_files
self.anno_files = anno_files
# generalize to non coco data
self.sample_id_key = sample_id_key
self.valid_splits = valid_splits
@classmethod
def setup_task(cls, cfg):
run_cfg = cfg.run_cfg
num_beams = run_cfg.get("num_beams", 3)
max_len = run_cfg.get("max_len", 10)
min_len = run_cfg.get("min_len", 1)
evaluate = run_cfg.get("evaluate", False)
inference_method = run_cfg.get("inference_method", "rank")
num_ans_candidates = run_cfg.get("num_ans_candidates", 128)
prompt = run_cfg.get("prompt", "")
# generalize to non coco data
sample_id_key = run_cfg.get("sample_id_key", "instance_id")
ques_files = run_cfg.get("ques_files", dict())
anno_files = run_cfg.get("anno_files", dict())
valid_splits = run_cfg.get("valid_splits", ["val"])
return cls(
num_beams=num_beams,
max_len=max_len,
min_len=min_len,
evaluate=evaluate,
num_ans_candidates=num_ans_candidates,
inference_method=inference_method,
prompt=prompt,
sample_id_key = sample_id_key,
ques_files=ques_files,
anno_files=anno_files,
valid_splits=valid_splits
)
def build_datasets(self, cfg):
datasets = super().build_datasets(cfg)
# get question file, annotation file and anwser list in COCO format
for ds_name, dataset in datasets.items():
for split in self.valid_splits:
if split not in dataset:
print(f"Split {split} not found in {ds_name}.")
if (
hasattr(dataset[split], "coco_fmt_qust_file")
and dataset[split].coco_fmt_qust_file is not None
):
self.ques_files[split] = dataset[split].coco_fmt_qust_file
self.anno_files[split] = dataset[split].coco_fmt_anno_file
else:
if split not in self.ques_files: # precomputed and passed in task builder
self.ques_files[split] = os.path.join(registry.get_path("cache_root"),f'{ds_name}_gt', f'{ds_name}_{split}_questions.json')
self.anno_files[split] = os.path.join(registry.get_path("cache_root"), f'{ds_name}_gt', f'{ds_name}_{split}_annotations.json')
if dist_utils.get_rank() == 0:
os.makedirs(os.path.join(registry.get_path("cache_root"),f'{ds_name}_gt'), exist_ok=True)
try:
convert_to_coco_gt(dataset, self.ques_files[split], self.anno_files[split], split, self.sample_id_key)
except:
pass # tasks like vizwiz with no gt answer
try:
self.answer_list = dataset[split].answer_list
except AttributeError:
# if answer_list is not provided, then set it to None
pass
if len(self.ques_files) > 0:
assert len(self.ques_files) == len(
self.anno_files
), "Only support one split for evaluation."
return datasets
def valid_step(self, model, samples):
answers = model.predict_answers(
samples=samples,
answer_list=self.answer_list,
inference_method=self.inference_method,
num_beams=self.num_beams,
max_len=self.max_len,
min_len=self.min_len,
num_ans_candidates=self.num_ans_candidates,
prompt=self.prompt,
)
pred_qa_pairs = []
question_id = samples["question_id"]
for answer, ques_id in zip(answers, question_id):
ques_id = int(ques_id.item()) if isinstance(ques_id, torch.Tensor) else ques_id
if ques_id != int and is_convertible_to_int(ques_id):
ques_id = int(ques_id)
pred_qa_pairs.append({"question_id": ques_id, "answer": answer})
return pred_qa_pairs
def after_evaluation(self, val_result, split_name, **kwargs):
result_file = self.save_result(
val_result,
result_dir=registry.get_path("result_dir"),
filename=f"{split_name}_vqa_result",
remove_duplicate="question_id",
)
metrics = self._report_metrics(result_file=result_file, split=split_name)
return metrics
@dist_utils.main_process
def _report_metrics(self, result_file, split):
"""
Use official VQA evaluation script to report metrics.
"""
metrics = {}
if split in self.ques_files and split in self.anno_files:
vqa = VQA(self.anno_files[split], self.ques_files[split])
vqa_result = vqa.loadRes(
resFile=result_file, quesFile=self.ques_files[split]
)
# create vqaEval object by taking vqa and vqaRes
# n is precision of accuracy (number of places after decimal), default is 2
vqa_scorer = VQAEval(vqa, vqa_result, n=2)
logging.info("Start VQA evaluation.")
vqa_scorer.evaluate()
# print accuracies
overall_acc = vqa_scorer.accuracy["overall"]
metrics["agg_metrics"] = overall_acc
logging.info("Overall Accuracy is: %.02f\n" % overall_acc)
logging.info("Per Answer Type Accuracy is the following:")
for ans_type in vqa_scorer.accuracy["perAnswerType"]:
logging.info(
"%s : %.02f"
% (ans_type, vqa_scorer.accuracy["perAnswerType"][ans_type])
)
metrics[ans_type] = vqa_scorer.accuracy["perAnswerType"][ans_type]
with open(
os.path.join(registry.get_path("output_dir"), "evaluate.txt"), "a"
) as f:
f.write(json.dumps(metrics) + "\n")
return metrics
def convert_to_coco_gt(data, outpath_questions, outpath_annotations, split, sample_id_key):
if split not in data:
return
questions_data = {'info':"", 'task_type':"", 'data_type':"", 'license':"", 'data_subtype':"", 'questions':[]}
annotations_data = {'info':"", 'task_type':"", 'data_type':"", 'license':"", 'data_subtype':"", 'annotations':[]}
print("Generating ground truth annotations...")
for ann in tqdm(data[split]):
if ann == None:
continue
# if ann[sample_id_key] not in img_ids:
# continue
ques_id = ann["question_id"]
ques_id = int(ques_id.item()) if isinstance(ques_id, torch.Tensor) else ques_id
if ques_id != int and is_convertible_to_int(ques_id):
ques_id = int(ques_id)
questions_data["questions"].append({"question": ann["text_input"], "image_id": ann[sample_id_key], "question_id": ques_id})
annotations_data["annotations"].append({
"question_type": "" if "question_type" not in ann else ann["question_type"],
"multiple_choice_answer": ann["answers"][0] if isinstance(ann["answers"], list) else ann["answers"],
"answers": [{"answer":ans, "answer_id":i} for i,ans in enumerate(ann["answers"])] if isinstance(ann["answers"], list) else [{"answer":ann["answers"], "answer_id":0}],
"image_id": ann[sample_id_key],
"question_id": ques_id,
"answer_type": "" if "answer_type" not in ann else ann["answer_type"],
})
json.dump(questions_data, open(outpath_questions, 'w'))
print(f"Saved questions data at {outpath_questions}")
json.dump(annotations_data, open(outpath_annotations, 'w'))
print(f"Saved annotation data at {outpath_annotations}")
@registry.register_task("aok_vqa")
class AOKVQATask(VQATask):
def valid_step(self, model, samples):
answers = model.predict_answers(
samples=samples,
answer_list=self.answer_list,
inference_method=self.inference_method,
num_beams=self.num_beams,
max_len=self.max_len,
min_len=self.min_len,
num_ans_candidates=self.num_ans_candidates,
)
pred_qa_pairs = []
question_id = samples["question_id"]
gt_answers = samples["direct_answers"]
for pred_answer, ques_id, gt_answer in zip(answers, question_id, gt_answers):
pred_qa_pairs.append(
{"question_id": ques_id, "pred_ans": pred_answer, "gt_ans": gt_answer}
)
return pred_qa_pairs
@dist_utils.main_process
def _report_metrics(self, result_file, split):
"""
Implementing accuracy computation for AOKVQA, see
https://github.com/allenai/aokvqa/blob/main/evaluation/eval_predictions.py#L45 for details.
"""
# TODO add evaluation for multi-choice
results = json.load(open(result_file, "r"))
acc = []
for res in results:
if res["gt_ans"] is None:
# prepare test results for leaderboard evaluation
self._save_result_leaderboard(results)
return
pred = res["pred_ans"]
gt_ans = res["gt_ans"]
num_match = sum([pred == gt for gt in gt_ans])
vqa_acc = min(1.0, num_match / 3.0)
acc.append(vqa_acc)
accuracy = sum(acc) / len(acc) * 100
metrics = {"agg_metrics": accuracy, "acc": accuracy}
with open(
os.path.join(registry.get_path("output_dir"), "evaluate.txt"), "a"
) as f:
f.write(json.dumps(metrics) + "\n")
logging.info(metrics)
return metrics
@dist_utils.main_process
def _save_result_leaderboard(self, results):
"""
Saving the results in the format required for leaderboard evaluation.
[TODO] add support for multi-choice.
"""
result_leaderboard = dict()
for res in results:
result_leaderboard[res["question_id"]] = {
"direct_answer": res["pred_ans"],
"multiple_choice": "",
}
result_file = registry.get_path("result_dir") + "_leaderboard.json"
with open(result_file, "w") as f:
json.dump(result_leaderboard, f)
logging.info(f"Saved results for leaderboard evaluation at {result_file}")
@registry.register_task("gqa")
class GQATask(VQATask):
def valid_step(self, model, samples):
answers = model.predict_answers(
samples=samples,
answer_list=self.answer_list,
inference_method=self.inference_method,
num_beams=self.num_beams,
max_len=self.max_len,
min_len=self.min_len,
num_ans_candidates=self.num_ans_candidates,
prompt=self.prompt,
)
pred_qa_pairs = []
question_id = samples["question_id"]
gt_answers = samples["answer"]
for answer, ques_id, gt_answer in zip(answers, question_id, gt_answers):
ques_id = int(ques_id.item()) if isinstance(ques_id, torch.Tensor) else ques_id
pred_qa_pairs.append({"question_id": ques_id, "pred_ans": answer, "gt_ans": gt_answer})
return pred_qa_pairs
def build_datasets(self, cfg):
datasets = BaseTask.build_datasets(self,cfg)
# get question file, annotation file and anwser list in COCO format
for ds_name, dataset in datasets.items():
for split in dataset:
if (
hasattr(dataset[split], "coco_fmt_qust_file")
and dataset[split].coco_fmt_qust_file is not None
):
self.ques_files[split] = dataset[split].coco_fmt_qust_file
self.anno_files[split] = dataset[split].coco_fmt_anno_file
if len(self.ques_files) > 0:
assert len(self.ques_files) == len(
self.anno_files
), "Only support one split for evaluation."
return datasets
@dist_utils.main_process
def _report_metrics(self, result_file, split):
"""
TODO: add other evaluation metrics for GQA
"""
results = json.load(open(result_file, "r"))
acc = []
vqa_tool = VQAEval()
for res in results:
if res["gt_ans"] is None:
# prepare test results for leaderboard evaluation
self._save_result_leaderboard(results)
return
gt_ans = res["gt_ans"]
pred = res["pred_ans"]
# if self.inference_method == "generate":
pred = vqa_tool.processPunctuation(pred)
pred = vqa_tool.processDigitArticle(pred)
# added to ensure that the ground truth format of answers is as expected for non-gqa but similar tasks
gt_ans = vqa_tool.processPunctuation(gt_ans)
gt_ans = vqa_tool.processDigitArticle(gt_ans)
vqa_acc = 1 if pred == gt_ans else 0
acc.append(vqa_acc)
accuracy = sum(acc) / len(acc) * 100
metrics = {"agg_metrics": accuracy, "acc": accuracy}
with open(
os.path.join(registry.get_path("output_dir"), "evaluate.txt"), "a"
) as f:
f.write(json.dumps(metrics) + "\n")
logging.info(metrics)
return metrics
@registry.register_task("discrn_qa")
class DisCRNTask(VQATask):
def valid_step(self, model, samples):
answers = model.predict_answers(
samples=samples,
answer_list=self.answer_list,
inference_method=self.inference_method,
num_beams=self.num_beams,
max_len=self.max_len,
min_len=self.min_len,
num_ans_candidates=self.num_ans_candidates,
prompt=self.prompt,
)
if answers == None: # corrupt videos
return []
pred_qa_pairs = []
question_id = samples["question_id"]
gt_answers = samples["answer"]
for answer, ques_id, gt_answer in zip(answers, question_id, gt_answers):
ques_id = int(ques_id.item()) if isinstance(ques_id, torch.Tensor) else ques_id
pred_qa_pairs.append({"question_id": ques_id, "pred_ans": answer, "gt_ans": gt_answer})
return pred_qa_pairs
def build_datasets(self, cfg):
datasets = BaseTask.build_datasets(self, cfg)
return datasets
@dist_utils.main_process
def _report_metrics(self, result_file, split):
results = json.load(open(result_file, "r"))
acc = []
vqa_tool = VQAEval()
for res in results:
gt_ans = res["gt_ans"]
pred = res["pred_ans"]
# gt_ans = [vqa_tool.processPunctuation(g) for g in gt_ans]
# gt_ans = [vqa_tool.processDigitArticle(g) for g in gt_ans]
# if self.inference_method == "generate":
pred = vqa_tool.processPunctuation(pred)
pred = vqa_tool.processDigitArticle(pred)
tokenized_pred = pred.strip().split(" ")
for ans in gt_ans:
if ans in tokenized_pred:
pred = ans
break
vqa_acc = 1 if pred in gt_ans else 0
acc.append(vqa_acc)
accuracy = sum(acc) / len(acc) * 100
metrics = {"agg_metrics": accuracy, "acc": accuracy}
with open(
os.path.join(registry.get_path("output_dir"), "evaluate.txt"), "a"
) as f:
f.write(json.dumps(metrics) + "\n")
logging.info(metrics)
return metrics