Spaces:
Build error
Build error
#!/usr/bin/env python3 | |
# Copyright 2017-present, Facebook, Inc. | |
# All rights reserved. | |
# | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
"""DrQA reader utilities.""" | |
import json | |
import time | |
import logging | |
import string | |
import regex as re | |
from collections import Counter | |
from .data import Dictionary | |
# Module-level logger named after this module, so its level and handlers can
# be configured through the standard logging hierarchy.
logger = logging.getLogger(name=__name__)
# ------------------------------------------------------------------------------ | |
# Data loading | |
# ------------------------------------------------------------------------------ | |
def load_data(args, filename, skip_no_answer=False):
    """Load examples from a preprocessed JSON-lines file.

    Args:
        args: options namespace; reads ``uncased_question`` and ``uncased_doc``.
        filename: path to a file holding one JSON-encoded example per line.
        skip_no_answer: if True, drop examples whose ``'answers'`` list is
            empty.

    Returns:
        list of example dicts. Question and/or document tokens are
        lowercased when the corresponding option is set.
    """
    # Load JSON lines. Decode as UTF-8 explicitly instead of relying on the
    # platform's locale-dependent default, and ignore whitespace-only lines
    # (e.g. a stray empty line at the end of the file), which would
    # otherwise make json.loads raise.
    with open(filename, encoding='utf-8') as f:
        examples = [json.loads(line) for line in f if line.strip()]

    # Make case insensitive?
    if args.uncased_question or args.uncased_doc:
        for ex in examples:
            if args.uncased_question:
                ex['question'] = [w.lower() for w in ex['question']]
            if args.uncased_doc:
                ex['document'] = [w.lower() for w in ex['document']]

    # Skip unparsed (start/end) examples
    if skip_no_answer:
        examples = [ex for ex in examples if ex['answers']]

    return examples
def load_text(filename):
    """Load only the paragraphs of a SQuAD-format dataset.

    Args:
        filename: path to a JSON file with a top-level ``'data'`` list of
            articles, each holding ``'paragraphs'`` with ``'context'`` and
            ``'qas'``.

    Returns:
        dict mapping each question id to the context paragraph it belongs to.
        Questions from the same paragraph share the same text.
    """
    # JSON exchanged between systems must be UTF-8 (RFC 8259); do not depend
    # on the platform's locale-dependent default encoding.
    with open(filename, encoding='utf-8') as f:
        examples = json.load(f)['data']
    texts = {}
    for article in examples:
        for paragraph in article['paragraphs']:
            for qa in paragraph['qas']:
                texts[qa['id']] = paragraph['context']
    return texts
def load_answers(filename):
    """Load only the answers of a SQuAD-format dataset.

    Args:
        filename: path to a JSON file with a top-level ``'data'`` list of
            articles, each holding ``'paragraphs'`` whose ``'qas'`` entries
            carry an ``'id'`` and a list of ``'answers'``.

    Returns:
        dict mapping each question id to the list of its answer strings (the
        ``'text'`` field of every answer entry, in file order).
    """
    # JSON exchanged between systems must be UTF-8 (RFC 8259); do not depend
    # on the platform's locale-dependent default encoding.
    with open(filename, encoding='utf-8') as f:
        examples = json.load(f)['data']
    ans = {}
    for article in examples:
        for paragraph in article['paragraphs']:
            for qa in paragraph['qas']:
                ans[qa['id']] = [answer['text'] for answer in qa['answers']]
    return ans
# ------------------------------------------------------------------------------ | |
# Dictionary building | |
# ------------------------------------------------------------------------------ | |
def index_embedding_words(embedding_file):
    """Return the set of normalized words listed in an embedding text file.

    Only the first space-separated field of each (right-stripped) line is
    read; it is passed through ``Dictionary.normalize`` before being added.

    Args:
        embedding_file: path to the embedding text file.

    Returns:
        set of normalized words.
    """
    words = set()
    # Decode explicitly as UTF-8 rather than with the locale-dependent
    # default. NOTE(review): assumes embedding files are UTF-8 text — confirm
    # for any non-standard embedding source.
    with open(embedding_file, encoding='utf-8') as f:
        for line in f:
            words.add(Dictionary.normalize(line.rstrip().split(' ')[0]))
    return words
def load_words(args, examples):
    """Collect the normalized vocabulary of all questions and documents.

    Args:
        args: options namespace; reads ``restrict_vocab`` and
            ``embedding_file``.
        examples: iterable of example dicts with ``'question'`` and
            ``'document'`` token lists.

    Returns:
        set of words normalized with ``Dictionary.normalize``. When
        ``args.restrict_vocab`` is set and ``args.embedding_file`` is given,
        only words that also appear in that embedding file are kept.
    """
    if args.restrict_vocab and args.embedding_file:
        logger.info('Restricting to words in %s', args.embedding_file)
        valid_words = index_embedding_words(args.embedding_file)
        logger.info('Num words in set = %d', len(valid_words))
    else:
        valid_words = None

    words = set()

    def _insert(iterable):
        # Normalize each token and keep it unless it falls outside the
        # restricted vocabulary. Compare against None explicitly: an empty
        # restriction set must filter out every word, not disable filtering.
        for w in iterable:
            w = Dictionary.normalize(w)
            if valid_words is not None and w not in valid_words:
                continue
            words.add(w)

    for ex in examples:
        _insert(ex['question'])
        _insert(ex['document'])
    return words
def build_word_dict(args, examples):
    """Build a ``Dictionary`` holding every question and document word.

    The vocabulary comes from ``load_words(args, examples)``, so the same
    normalization and optional embedding-file restriction apply.
    """
    dictionary = Dictionary()
    vocabulary = load_words(args, examples)
    for token in vocabulary:
        dictionary.add(token)
    return dictionary
def top_question_words(args, examples, word_dict):
    """Count and return the most common question words found in ``word_dict``.

    Args:
        args: options namespace; ``args.tune_partial`` is passed to
            ``Counter.most_common`` as the number of entries to return.
        examples: iterable of example dicts with a ``'question'`` token list.
        word_dict: vocabulary; only normalized words it contains are counted.

    Returns:
        list of ``(word, count)`` pairs, most common first.
    """
    normalized = (
        Dictionary.normalize(w) for ex in examples for w in ex['question']
    )
    # Build the Counter in one pass instead of calling update([w]) per token;
    # counts and first-seen order (used to break ties) are unchanged.
    word_count = Counter(w for w in normalized if w in word_dict)
    return word_count.most_common(args.tune_partial)
def build_feature_dict(args, examples):
    """Assign a unique consecutive index to every enabled one-hot feature.

    Indices are handed out in first-seen order:
      1. exact-match features (``in_question``, ``in_question_uncased``, and
         ``in_question_lemma`` if lemmas are also enabled);
      2. ``pos=<tag>`` features from each example's ``'pos'`` list;
      3. ``ner=<tag>`` features from each example's ``'ner'`` list;
      4. the ``tf`` term-frequency feature.

    Args:
        args: options namespace; reads ``use_in_question``, ``use_lemma``,
            ``use_pos``, ``use_ner`` and ``use_tf``.
        examples: iterable of example dicts. ``'pos'`` and ``'ner'`` are read
            only when the matching option is enabled.

    Returns:
        dict of feature name -> index.
    """
    feature_dict = {}

    def _insert(feature):
        # setdefault keeps an existing index untouched; otherwise the
        # feature gets the next free index, the current size of the dict.
        feature_dict.setdefault(feature, len(feature_dict))

    # Exact match features
    if args.use_in_question:
        for name in ('in_question', 'in_question_uncased'):
            _insert(name)
        if args.use_lemma:
            _insert('in_question_lemma')

    # Part of speech tag features
    if args.use_pos:
        pos_tags = (tag for ex in examples for tag in ex['pos'])
        for tag in pos_tags:
            _insert('pos=%s' % tag)

    # Named entity tag features
    if args.use_ner:
        ner_tags = (tag for ex in examples for tag in ex['ner'])
        for tag in ner_tags:
            _insert('ner=%s' % tag)

    # Term frequency feature
    if args.use_tf:
        _insert('tf')

    return feature_dict
# ------------------------------------------------------------------------------ | |
# Evaluation. Follows official evaluation script for v1.1 of the SQuAD dataset.
# ------------------------------------------------------------------------------ | |
def normalize_answer(s):
    """Normalize an answer string for SQuAD-style comparison.

    Steps, in order: lowercase; drop every ``string.punctuation`` character;
    replace the whole words "a", "an" and "the" with a space; collapse runs
    of whitespace into single spaces and trim the ends.
    """
    punctuation = set(string.punctuation)
    text = s.lower()
    text = ''.join(ch for ch in text if ch not in punctuation)
    text = re.sub(r'\b(a|an|the)\b', ' ', text)
    return ' '.join(text.split())
def f1_score(prediction, ground_truth):
    """Compute token-level F1 between a prediction and one ground truth.

    Both strings are normalized with ``normalize_answer`` and split on
    whitespace. Token overlap is counted as a multiset intersection. F1 is
    the harmonic mean of precision and recall.

    Returns:
        0 when no tokens overlap (this includes empty inputs); otherwise a
        float in (0, 1].
    """
    prediction_tokens = normalize_answer(prediction).split()
    ground_truth_tokens = normalize_answer(ground_truth).split()
    # Multiset intersection: a shared token counts min(pred, gt) times.
    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        # Also protects the divisions below from empty token lists.
        return 0
    precision = 1.0 * num_same / len(prediction_tokens)
    recall = 1.0 * num_same / len(ground_truth_tokens)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1
def exact_match_score(prediction, ground_truth):
    """Return True when both strings are equal after ``normalize_answer``.

    The match is "soft" because normalization ignores case, punctuation,
    the articles a/an/the, and extra whitespace.
    """
    normalized_prediction = normalize_answer(prediction)
    normalized_truth = normalize_answer(ground_truth)
    return normalized_prediction == normalized_truth
def regex_match_score(prediction, pattern):
    """Check whether ``prediction`` matches the regular expression ``pattern``.

    The pattern is compiled with IGNORECASE, UNICODE and MULTILINE flags and
    applied with ``match``. The match is therefore anchored at the start of
    ``prediction``, but need not reach its end.

    Returns:
        True on a match. False when there is no match, or when ``pattern``
        cannot be compiled; in that case a warning is logged.
    """
    try:
        compiled = re.compile(
            pattern,
            flags=re.IGNORECASE | re.UNICODE | re.MULTILINE
        )
    except (re.error, TypeError, ValueError):
        # Swallow only compile failures (bad syntax, non-string pattern,
        # incompatible flags). KeyboardInterrupt, SystemExit and unrelated
        # errors must still propagate.
        logger.warning('Regular expression failed to compile: %s', pattern)
        return False
    return compiled.match(prediction) is not None
def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    """Score ``prediction`` against every ground truth and return the best.

    Args:
        metric_fn: callable ``(prediction, ground_truth) -> score``.
        prediction: the predicted answer.
        ground_truths: iterable of acceptable answers. It must be non-empty,
            otherwise ``max`` raises ``ValueError``.

    Returns:
        the maximum score produced by ``metric_fn``.
    """
    return max(metric_fn(prediction, truth) for truth in ground_truths)
# ------------------------------------------------------------------------------ | |
# Utility classes | |
# ------------------------------------------------------------------------------ | |
class AverageMeter(object):
    """Keep the latest value plus a running, count-weighted average."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the latest value, the average, the running sum and the count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the value of ``n`` new samples.

        ``sum`` grows by ``val * n`` and ``count`` by ``n``. ``avg`` is then
        recomputed as ``sum / count``.
        """
        weighted = val * n
        self.val = val
        self.sum += weighted
        self.count += n
        self.avg = self.sum / self.count
class Timer(object):
    """Stopwatch that accumulates elapsed seconds across stop/resume cycles.

    Times are taken from ``time.perf_counter``, a monotonic high-resolution
    clock, so measurements are not skewed by system wall-clock adjustments
    (which ``time.time`` is subject to). The timer starts running as soon as
    it is constructed.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Discard accumulated time and start running from now. Returns self."""
        self.running = True
        self.total = 0
        self.start = time.perf_counter()
        return self

    def resume(self):
        """Start running again if stopped; no-op if already running. Returns self."""
        if not self.running:
            self.running = True
            self.start = time.perf_counter()
        return self

    def stop(self):
        """Stop running and add the current interval to the total. Returns self."""
        if self.running:
            self.running = False
            self.total += time.perf_counter() - self.start
        return self

    def time(self):
        """Return elapsed seconds, including the current interval if running."""
        # Inside this method the name ``time`` still refers to the module:
        # class attributes are not part of a method's name resolution scope.
        if self.running:
            return self.total + time.perf_counter() - self.start
        return self.total