# Spaces:
# Runtime error
# Runtime error
import datasets
from llm.qa_agent import QnAAgent

# TriviaQA does not publish gold answers for its test split (the HF `trivia_qa`
# test rows carry empty answer aliases), so scoring against "test" would report
# 0% exact match regardless of the agent. Evaluate on validation instead.
validation_dataset = datasets.load_dataset(
    "trivia_qa", "rc", split="validation[:5%]"
)  # remove [:5%] to run on full validation set

# Characters replaced by a space during answer normalization: left/right curly
# single quotes, acute accent, backtick, period, comma, hyphen and double quote.
# Written as escapes so the non-ASCII entries survive any re-encoding of this file.
PUNCTUATION_SET_TO_EXCLUDE = {"\u2018", "\u2019", "\u00b4", "`", ".", ",", "-", '"'}

# Single agent instance shared by every `evaluate` call below.
qna_agent = QnAAgent()
def get_sub_answers(answers, begin=0, end=None):
    """Return the multi-word answers, each cut down to ``words[begin:end]``.

    Words are obtained by splitting on single spaces. Answers consisting of a
    single word are skipped entirely.
    """
    sub_answers = []
    for answer in answers:
        words = answer.split(" ")
        if len(words) <= 1:
            continue
        sub_answers.append(" ".join(words[begin:end]))
    return sub_answers
def expand_to_aliases(given_answers, make_sub_answers=False):
    """Normalize answer strings into a set of comparable aliases.

    Each answer has underscores replaced by spaces and is lower-cased. Every
    character in ``PUNCTUATION_SET_TO_EXCLUDE`` is replaced by a space, and runs
    of whitespace are collapsed into single spaces.

    Args:
        given_answers: iterable of answer strings.
        make_sub_answers: if True, also add every multi-word answer with its
            first word removed and with its last word removed. A prediction is
            then still correct if it matches the answer without a leading
            article such as "the" or "a" (or without its last word).

    Returns:
        A set of normalized, non-empty alias strings.
    """
    given_answers = list(given_answers)
    if make_sub_answers:
        given_answers = (
            given_answers
            + get_sub_answers(given_answers, begin=1)
            + get_sub_answers(given_answers, end=-1)
        )
    aliases = set()
    for answer in given_answers:
        alias = answer.replace("_", " ").lower()
        alias = "".join(
            " " if c in PUNCTUATION_SET_TO_EXCLUDE else c for c in alias
        )
        alias = " ".join(alias.split())
        # An alias that normalizes to "" (punctuation-only text, or a sub-answer
        # cut from a string like "word ") would let an empty prediction count as
        # a match, so it is dropped.
        if alias:
            aliases.add(alias)
    return aliases
def evaluate(example):
    """Query the QnA agent for one example and score its answers by exact match.

    The agent is asked ``example["question"]`` twice: once with
    ``use_context=False`` and once with ``use_context=True``. Both answers are
    normalized with ``expand_to_aliases`` and compared with the gold aliases.

    Fields added to ``example``, which is then returned:
        output: answer obtained with ``use_context=False``.
        output_context: answer obtained with ``use_context=True``.
        targets: ``example["answer"]["aliases"]``.
        match / match_context: True when the corresponding normalized answer
            equals at least one normalized gold alias, including the gold
            aliases' first-/last-word-trimmed variants.
    """
    question = example["question"]
    example["output"] = qna_agent.get_answer(question, use_context=False)
    example["output_context"] = qna_agent.get_answer(question, use_context=True)
    example["targets"] = example["answer"]["aliases"]

    answers = expand_to_aliases(example["targets"], make_sub_answers=True)
    predictions = expand_to_aliases([example["output"]])
    predictions_with_context = expand_to_aliases([example["output_context"]])

    # A prediction is a match if it shares at least one alias with the targets.
    example["match"] = not answers.isdisjoint(predictions)
    example["match_context"] = not answers.isdisjoint(predictions_with_context)
    return example
results = validation_dataset.map(evaluate)

# Report the exact-match percentage over all examples, first for the answers
# produced without context, then for those produced with context.
for em_template, match_column in (
    ("Exact Match (EM) without context: {:.2f}", "match"),
    ("Exact Match (EM) with context: {:.2f}", "match_context"),
):
    print(em_template.format(100 * sum(results[match_column]) / len(results)))