""" |
Set of utilities for Q&A results validation tasks - Retriever passage
validation and Reader predicted answer validation
""" |
import collections |
import logging |
import string |
import unicodedata |
from functools import partial |
from multiprocessing import Pool as ProcessPool |
from typing import Tuple, List, Dict |
import regex as re |
from tasks.orqa.unsupervised.tokenizers import SimpleTokenizer |
# Module-level logger shared by every function in this file.
logger = logging.getLogger(__name__)

# Result record returned by calculate_matches:
#   top_k_hits         - cumulative hit counts per retrieval rank
#   questions_doc_hits - per-question list of per-document hit flags
QAMatchStats = collections.namedtuple(
    'QAMatchStats',
    ['top_k_hits', 'questions_doc_hits'],
)
def calculate_matches(
    all_docs: Dict[object, Tuple[str, str]],
    answers: List[List[str]],
    closest_docs: List[Tuple[List[object], List[float]]],
    workers_num: int,
    match_type: str,
) -> QAMatchStats:
    """
    Evaluates answers presence in the set of documents. This function is
    supposed to be used with a large collection of documents and results.
    It internally starts a pool of worker processes for evaluation and then
    merges results.

    :param all_docs: dictionary of the entire documents database.
        doc_id -> (doc_text, title)
    :param answers: list of answer lists. One list per question
    :param closest_docs: document ids of the top results along with their
        scores, one (doc_ids, doc_scores) pair per question
    :param workers_num: number of parallel worker processes used to process
        the data
    :param match_type: type of answer matching. Refer to has_answer code for
        available options
    :return: matching information tuple.
        top_k_hits - a list where index k holds the number of questions for
        which an answer was found in one of the first k + 1 retrieved
        documents. Its length is the number of documents retrieved for the
        first question (empty when closest_docs is empty).
        questions_doc_hits - more detailed info with answer matches for every
        question and every retrieved document
    """
    # check_answer reads the document store through this module-level global,
    # so it has to be published before the worker pool is created.
    # NOTE(review): this relies on the workers inheriting module globals
    # (fork start method) - confirm on platforms that default to spawn.
    global dpr_all_documents
    dpr_all_documents = all_docs

    tokenizer = SimpleTokenizer()
    get_score_partial = partial(check_answer, match_type=match_type,
                                tokenizer=tokenizer)
    questions_answers_docs = zip(answers, closest_docs)

    logger.info('Matching answers in top docs...')
    # The context manager shuts the pool down once map() has returned, so the
    # worker processes are not leaked after every call.
    with ProcessPool(processes=workers_num) as processes:
        scores = processes.map(get_score_partial, questions_answers_docs)
    logger.info('Per question validation results len=%d', len(scores))

    # Guard against an empty result set instead of failing on closest_docs[0].
    n_docs = len(closest_docs[0][0]) if closest_docs else 0
    top_k_hits = [0] * n_docs
    for question_hits in scores:
        # Rank of the first document that contains an answer, if any.
        best_hit = next((i for i, x in enumerate(question_hits) if x), None)
        if best_hit is not None:
            # A hit at rank r also counts as a hit for every larger top-k.
            for rank in range(best_hit, n_docs):
                top_k_hits[rank] += 1

    return QAMatchStats(top_k_hits, scores)
def check_answer(questions_answers_docs, tokenizer, match_type) -> List[bool]:
    """
    Search through all the top docs to see if they have any of the answers.

    :param questions_answers_docs: pair of the form
        (answers, (doc_ids, doc_scores)) for a single question
    :param tokenizer: tokenizer forwarded to has_answer
    :param match_type: matching mode forwarded to has_answer
    :return: one boolean per entry of doc_ids, in the same order; True when
        the document text contains one of the answers
    """
    # The document store is published as a module global by calculate_matches.
    global dpr_all_documents
    answers, (doc_ids, _doc_scores) = questions_answers_docs

    hits = []
    for doc_id in doc_ids:
        # Each entry is indexed positionally; element 0 is the document text.
        text = dpr_all_documents[doc_id][0]
        if text is None:
            logger.warning("no doc in db")
            hits.append(False)
        else:
            hits.append(bool(has_answer(answers, text, tokenizer, match_type)))
    return hits
def has_answer(answers, text, tokenizer, match_type) -> bool:
    """
    Check if a document contains an answer string.

    If `match_type` is string, token matching is done between the text
    and answer.
    If `match_type` is regex, we search the whole text with the regex.
    Any other `match_type` yields False.

    :param answers: iterable of candidate answer strings (or regex patterns
        when `match_type` is regex)
    :param text: document text to search in
    :param tokenizer: object providing tokenize(text).words(uncased=True);
        only used for the string match type
    :param match_type: 'string' or 'regex'
    :return: True as soon as one answer is found in the text, else False
    """
    text = _normalize(text)

    if match_type == 'string':
        text_tokens = tokenizer.tokenize(text).words(uncased=True)
        text_len = len(text_tokens)
        for single_answer in answers:
            answer_tokens = tokenizer.tokenize(
                _normalize(single_answer)).words(uncased=True)
            answer_len = len(answer_tokens)
            if answer_len == 0:
                # An answer without tokens would compare equal to the empty
                # slice at every position and match any document.
                continue
            # Slide a window of the answer's length over the document tokens.
            for start in range(text_len - answer_len + 1):
                if answer_tokens == text_tokens[start:start + answer_len]:
                    return True
    elif match_type == 'regex':
        for single_answer in answers:
            if regex_match(text, _normalize(single_answer)):
                return True

    return False
def regex_match(text, pattern):
    """Test if a regex pattern is contained within a text.

    :param text: string to search in
    :param pattern: regular expression source to compile and search for
    :return: True when the pattern matches somewhere in the text; False when
        it does not match or when the pattern cannot be compiled
    """
    try:
        compiled = re.compile(pattern)
    except Exception:
        # Best effort: an unusable pattern counts as "no match". Catching
        # Exception rather than BaseException lets KeyboardInterrupt and
        # SystemExit propagate instead of being silently swallowed.
        return False
    return compiled.search(text) is not None
def exact_match_score(prediction, ground_truth):
    """Return True when prediction and ground truth are equal after both
    have been passed through _normalize_answer."""
    normalized_prediction = _normalize_answer(prediction)
    normalized_truth = _normalize_answer(ground_truth)
    return normalized_prediction == normalized_truth
def _normalize_answer(s): |
def remove_articles(text): |
return re.sub(r'\b(a|an|the)\b', ' ', text) |
def white_space_fix(text): |
return ' '.join(text.split()) |
def remove_punc(text): |
exclude = set(string.punctuation) |
return ''.join(ch for ch in text if ch not in exclude) |
def lower(text): |
return text.lower() |
return white_space_fix(remove_articles(remove_punc(lower(s)))) |
def _normalize(text): |
return unicodedata.normalize('NFD', text) |