import datasets

from llm.qa_agent import QnAAgent

# The split string selects which part of TriviaQA ("rc" config) is evaluated.
# A sliced split such as "test[:5%]" can be used for a quick partial run.
# NOTE(review): the variable is called `validation_dataset` but the "test"
# split is loaded. The TriviaQA test split may not ship gold answer aliases,
# in which case every match below would be False — confirm, and switch to
# "validation" if so.
validation_dataset = datasets.load_dataset("trivia_qa", "rc", split="test")

# Characters that are replaced by a space before answers are compared.
PUNCTUATION_SET_TO_EXCLUDE = set("".join(["‘", "’", "´", "`", ".", ",", "-", '"']))

qna_agent = QnAAgent()


def get_sub_answers(answers, begin=0, end=None):
    """Return the ``[begin:end]`` word slice of every multi-word answer.

    Answers are split on single spaces; answers consisting of one word are
    skipped entirely.

    Args:
        answers: Iterable of answer strings.
        begin: Index of the first word to keep.
        end: Index one past the last word to keep (``None`` keeps the rest).

    Returns:
        A list with one re-joined string per multi-word answer.
    """
    sub_answers = []
    for answer in answers:
        words = answer.split(" ")  # split once, reuse for the test and slice
        if len(words) > 1:
            sub_answers.append(" ".join(words[begin:end]))
    return sub_answers


def expand_to_aliases(given_answers, make_sub_answers=False):
    """Normalize answers into a set of comparable alias strings.

    Normalization replaces underscores and the characters in
    ``PUNCTUATION_SET_TO_EXCLUDE`` with spaces, lower-cases the text and
    collapses runs of whitespace. Aliases that normalize to the empty string
    are dropped so that a blank or punctuation-only string can never count
    as a match.

    Args:
        given_answers: Iterable of answer strings.
        make_sub_answers: When true, each multi-word answer additionally
            contributes the variants without its first word and without its
            last word.

    Returns:
        A set of normalized, non-empty alias strings.
    """
    # Copy to a list so tuples and other sequences can be concatenated below.
    candidates = list(given_answers)
    if make_sub_answers:
        # If an answer is longer than one word, a prediction is also counted
        # as correct when it corresponds to the complete [1:] or [:-1] word
        # slice, e.g. when the correct answer carries a prefix such as "the"
        # or "a".
        candidates = (
            candidates
            + get_sub_answers(candidates, begin=1)
            + get_sub_answers(candidates, end=-1)
        )

    aliases = set()
    for candidate in candidates:
        lowered = candidate.replace("_", " ").lower()
        cleaned = "".join(
            " " if char in PUNCTUATION_SET_TO_EXCLUDE else char for char in lowered
        )
        normalized = " ".join(cleaned.split())
        if normalized:  # skip aliases that vanished during normalization
            aliases.add(normalized)
    return aliases


def evaluate(example):
    """Query the QnA agent for one example and record whether it matched.

    Used as the callable passed to ``validation_dataset.map``. The agent is
    asked the example's question twice, first with ``use_context=False`` and
    then with ``use_context=True``.

    Args:
        example: A dataset row providing ``"question"`` and
            ``example["answer"]["aliases"]``.

    Returns:
        The same row with ``"output"``, ``"output_context"``, ``"targets"``,
        ``"match"`` and ``"match_context"`` added.
    """
    question = example["question"]
    example["output"] = qna_agent.get_answer(question, use_context=False)
    example["output_context"] = qna_agent.get_answer(question, use_context=True)
    example["targets"] = example["answer"]["aliases"]

    answers = expand_to_aliases(example["targets"], make_sub_answers=True)
    predictions = expand_to_aliases([example["output"]])
    predictions_with_context = expand_to_aliases([example["output_context"]])

    # Any alias shared between targets and prediction counts as a match.
    example["match"] = not answers.isdisjoint(predictions)
    example["match_context"] = not answers.isdisjoint(predictions_with_context)
    return example


def _exact_match_percentage(matches, total):
    """Return the share of true flags in *matches* as a percentage of *total*.

    Returns 0.0 when *total* is zero so an empty dataset does not raise
    ``ZeroDivisionError``.
    """
    if total == 0:
        return 0.0
    return 100 * sum(matches) / total


results = validation_dataset.map(evaluate)

print(
    "Exact Match (EM) without context: {:.2f}".format(
        _exact_match_percentage(results["match"], len(results))
    )
)
print(
    "Exact Match (EM) with context: {:.2f}".format(
        _exact_match_percentage(results["match_context"], len(results))
    )
)