|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
Set of utilities for Q&A results validation tasks - Retriver passage |
|
validation and Reader predicted answer validation |
|
""" |
|
|
|
import collections |
|
import logging |
|
import string |
|
import unicodedata |
|
from functools import partial |
|
from multiprocessing import Pool as ProcessPool |
|
from typing import Tuple, List, Dict |
|
|
|
import regex as re |
|
from tasks.orqa.unsupervised.tokenizers import SimpleTokenizer |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
# Aggregated matching results: top_k_hits is the cumulative hit count per
# retrieval depth; questions_doc_hits holds the per-question, per-document
# boolean match flags.
QAMatchStats = collections.namedtuple(
    'QAMatchStats',
    ['top_k_hits', 'questions_doc_hits'],
)
|
|
|
def calculate_matches(all_docs: Dict[object, Tuple[str, str]],
        answers: List[List[str]], closest_docs: List[Tuple[List[object],
        List[float]]], workers_num: int, match_type: str) -> QAMatchStats:
    """
    Evaluates answers presence in the set of documents. This function is
    supposed to be used with a large collection of documents and results.
    It internally forks multiple sub-processes for evaluation and then
    merges results
    :param all_docs: dictionary of the entire documents database.
        doc_id -> (doc_text, title)
    :param answers: list of answers' lists. One list per question
    :param closest_docs: document ids of the top results along with their
        scores
    :param workers_num: amount of parallel worker processes to process data
    :param match_type: type of answer matching. Refer to has_answer code for
        available options
    :return: matching information tuple.
        top_k_hits - a list where the index is the amount of top documents
        retrieved and the value is the total amount of valid matches across
        an entire dataset.
        questions_doc_hits - more detailed info with answer matches for every
        question and every retrieved document
    """
    # Publish the document store as a module-level global so that forked
    # worker processes inherit it instead of pickling it per task.
    global dpr_all_documents
    dpr_all_documents = all_docs

    tok_opts = {}
    tokenizer = SimpleTokenizer(**tok_opts)

    logger.info('Matching answers in top docs...')

    get_score_partial = partial(check_answer, match_type=match_type,
                                tokenizer=tokenizer)

    questions_answers_docs = zip(answers, closest_docs)

    # Use the pool as a context manager so the worker processes are always
    # terminated and reaped, even if scoring raises (the previous version
    # never closed the pool, leaking processes).
    with ProcessPool(processes=workers_num) as processes:
        scores = processes.map(get_score_partial, questions_answers_docs)

    logger.info('Per question validation results len=%d', len(scores))

    # top_k_hits[k] counts questions whose first matching document appears
    # at rank <= k; a hit at rank b therefore increments every bucket >= b.
    n_docs = len(closest_docs[0][0])
    top_k_hits = [0] * n_docs
    for question_hits in scores:
        best_hit = next((i for i, x in enumerate(question_hits) if x), None)
        if best_hit is not None:
            top_k_hits[best_hit:] = [v + 1 for v in top_k_hits[best_hit:]]

    return QAMatchStats(top_k_hits, scores)
|
|
|
|
|
def check_answer(questions_answers_docs, tokenizer, match_type) -> List[bool]:
    """
    Search through all the top docs to see if they have any of the answers.

    Returns one boolean per retrieved document, in retrieval order.
    """
    answers, (doc_ids, doc_scores) = questions_answers_docs

    # Populated by calculate_matches before the worker pool is forked.
    global dpr_all_documents

    hits = []
    for doc_id in doc_ids:
        text = dpr_all_documents[doc_id][0]
        if text is None:
            # Missing document body: count as a miss rather than crashing.
            logger.warning("no doc in db")
            hits.append(False)
        else:
            hits.append(has_answer(answers, text, tokenizer, match_type))
    return hits
|
|
|
|
|
def has_answer(answers, text, tokenizer, match_type) -> bool:
    """
    Check if a document contains an answer string.
    If `match_type` is string, token matching is done between the text
    and answer.
    If `match_type` is regex, we search the whole text with the regex.
    """
    text = _normalize(text)

    if match_type == 'string':
        doc_words = tokenizer.tokenize(text).words(uncased=True)
        n_doc = len(doc_words)

        for candidate in answers:
            answer_words = tokenizer.tokenize(
                _normalize(candidate)).words(uncased=True)
            n_ans = len(answer_words)

            # Slide the answer token window over the document tokens.
            if any(doc_words[start: start + n_ans] == answer_words
                   for start in range(n_doc - n_ans + 1)):
                return True

    elif match_type == 'regex':
        if any(regex_match(text, _normalize(candidate))
               for candidate in answers):
            return True

    return False
|
|
|
|
|
def regex_match(text, pattern):
    """Test if a regex pattern is contained within a text.

    :param text: text to search
    :param pattern: regex pattern string; an invalid pattern is treated as
        a non-match rather than an error
    :return: True if the pattern matches anywhere in the text
    """
    try:
        compiled = re.compile(
            pattern,
            # Combine flags with `|` (they are bit flags, not integers to sum).
            flags=re.IGNORECASE | re.UNICODE | re.MULTILINE,
        )
    # Catch only compilation errors: the previous `except BaseException`
    # also swallowed KeyboardInterrupt/SystemExit.
    except re.error:
        return False
    return compiled.search(text) is not None
|
|
|
|
|
|
|
def exact_match_score(prediction, ground_truth):
    """Return True iff prediction equals the ground truth after SQuAD-style
    normalization (lowercase, no punctuation/articles, collapsed spaces)."""
    prediction_norm = _normalize_answer(prediction)
    ground_truth_norm = _normalize_answer(ground_truth)
    return prediction_norm == ground_truth_norm
|
|
|
|
|
def _normalize_answer(s): |
|
def remove_articles(text): |
|
return re.sub(r'\b(a|an|the)\b', ' ', text) |
|
|
|
def white_space_fix(text): |
|
return ' '.join(text.split()) |
|
|
|
def remove_punc(text): |
|
exclude = set(string.punctuation) |
|
return ''.join(ch for ch in text if ch not in exclude) |
|
|
|
def lower(text): |
|
return text.lower() |
|
|
|
return white_space_fix(remove_articles(remove_punc(lower(s)))) |
|
|
|
|
|
def _normalize(text): |
|
return unicodedata.normalize('NFD', text) |
|
|