# realtime-qa / triviaQA.py
# Author: mohit-raghavendra (commit 3060e5b, "Upload 25 files")
import datasets
from llm.qa_agent import QnAAgent
# Evaluate on the TriviaQA "rc" *validation* split. The public test split
# withholds the gold answers (its alias lists are empty), so exact match
# computed on it would always be 0.
# Append a slice such as "[:5%]" to the split name to evaluate only a subset.
validation_dataset = datasets.load_dataset(
    "trivia_qa", "rc", split="validation"
)
# Characters replaced by a space when normalising answers and predictions.
# Non-ASCII characters are written as escapes so that a wrong file encoding
# can no longer garble them into multi-character mojibake.
PUNCTUATION_SET_TO_EXCLUDE = {
    "\u2018",  # LEFT SINGLE QUOTATION MARK
    "\u2019",  # RIGHT SINGLE QUOTATION MARK
    "\u00b4",  # ACUTE ACCENT
    "`",
    ".",
    ",",
    "-",
    '"',
}
# Single agent instance shared by every evaluate() call; evaluate() queries it
# twice per example (once with use_context=False, once with use_context=True).
qna_agent = QnAAgent()
def get_sub_answers(answers, begin=0, end=None):
    """Return word-slices of the multi-word answers.

    Each answer is split on single spaces. Answers that yield only one piece
    are skipped; for the rest, the pieces ``[begin:end]`` are re-joined with a
    single space. Order of the input is preserved.
    """
    sliced = []
    for answer in answers:
        words = answer.split(" ")
        if len(words) <= 1:
            # Single-word answers have no meaningful sub-answer.
            continue
        sliced.append(" ".join(words[begin:end]))
    return sliced
def expand_to_aliases(given_answers, make_sub_answers=False):
    """Normalise answer strings into a set of comparable aliases.

    Each answer is lower-cased, underscores and the characters in
    PUNCTUATION_SET_TO_EXCLUDE are replaced by spaces, and runs of whitespace
    are collapsed into a single space.

    Args:
        given_answers: sequence of answer strings.
        make_sub_answers: if True, also add, for every multi-word answer, the
            variant without its first word and the variant without its last
            word. This lets a prediction count as correct when it differs only
            by a leading word such as "the" or "a", or by a trailing word.

    Returns:
        A set of normalised alias strings. Aliases that normalise to the empty
        string are dropped, so an empty or punctuation-only prediction can
        never produce a match.
    """
    if make_sub_answers:
        given_answers = (
            list(given_answers)
            + get_sub_answers(given_answers, begin=1)
            + get_sub_answers(given_answers, end=-1)
        )

    aliases = set()
    for answer in given_answers:
        alias = answer.replace("_", " ").lower()
        alias = "".join(
            " " if c in PUNCTUATION_SET_TO_EXCLUDE else c for c in alias
        )
        # split()/join() collapses whitespace and trims both ends.
        alias = " ".join(alias.split())
        if alias:
            aliases.add(alias)
    return aliases
def evaluate(example):
    """Ask the QnA agent a question with and without context and score both.

    Used with ``validation_dataset.map``: ``example`` is updated in place and
    returned with these extra fields:

    - ``output``: agent answer obtained with ``use_context=False``
    - ``output_context``: agent answer obtained with ``use_context=True``
    - ``targets``: the gold ``example["answer"]["aliases"]`` list
    - ``match`` / ``match_context``: True when the normalised prediction shares
      at least one alias with the normalised gold answers
    """
    question = example["question"]
    example["output"] = qna_agent.get_answer(question, use_context=False)
    example["output_context"] = qna_agent.get_answer(question, use_context=True)
    example["targets"] = example["answer"]["aliases"]

    gold_aliases = expand_to_aliases(example["targets"], make_sub_answers=True)
    plain_aliases = expand_to_aliases([example["output"]])
    context_aliases = expand_to_aliases([example["output_context"]])

    # Any overlap between normalised prediction and gold aliases is a match.
    example["match"] = bool(gold_aliases & plain_aliases)
    example["match_context"] = bool(gold_aliases & context_aliases)
    return example
results = validation_dataset.map(evaluate)
num_examples = len(results)
if num_examples == 0:
    # An empty split/slice has no "match" columns and would otherwise raise
    # (missing column or ZeroDivisionError); EM is undefined in that case.
    print("No examples were evaluated; Exact Match (EM) is undefined.")
else:
    em = 100 * sum(results["match"]) / num_examples
    em_context = 100 * sum(results["match_context"]) / num_examples
    print("Exact Match (EM) without context: {:.2f}".format(em))
    print("Exact Match (EM) with context: {:.2f}".format(em_context))