# musique / musique.py
# bdsaglam's picture
# add musique metric
# 898eb24
# raw
# history blame
# 5.85 kB
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Exact Match and F1 question-answering metrics for the MuSiQue-Answerable dataset."""
import re
import string
import collections
from typing import Callable
import evaluate
import datasets
# TODO: Add BibTeX citation
_CITATION = """\
@InProceedings{huggingface:module,
title = {A great new module},
authors={huggingface, Inc.},
year={2020}
}
"""
_DESCRIPTION = """\
Question-answering metrics (`Exact Match` and `F1`) for Musique-Answerable dataset.
The implementation is taken from Musique repository.
https://github.com/StonyBrookNLP/musique
"""
_KWARGS_DESCRIPTION = """
Calculates how good are predictions given some references, using certain scores
Args:
predictions: list of predicted answers.
references: list of ground truth answers. Each reference should be a list of
ground truth answers for the corresponding prediction.
Returns:
exact_match: Exact match score,
f1: F1 score over tokens
Examples:
>>> my_new_module = evaluate.load("musique")
>>> results = my_new_module.compute(
references=[["New York City", "NYC"], ["Einstein", "Albert Einstein"]],
predictions=["New York City", "Albert Einstein"],
)
>>> print(results)
{'exact_match': 1.0, 'f1': 1.0}
"""
@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class musique(evaluate.Metric):
    """Question-answering metrics (Exact Match and F1) for the MuSiQue-Answerable dataset.

    Each prediction is a single answer string and each reference is a list of
    acceptable gold answers. Per example, the best score over the gold answers
    is kept; the per-example scores are then averaged over all examples.
    """

    def _info(self):
        """Describe the metric's metadata and the per-example input schema."""
        return evaluate.MetricInfo(
            # This is the description that will appear on the module's page.
            module_type="metric",
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            # Per-example format. `_compute` hands each prediction to the
            # string normalizer, so it must be a plain string (not a list),
            # and each reference is consumed as a flat list of gold answers.
            features=datasets.Features(
                {
                    "predictions": datasets.Value("string"),
                    "references": datasets.features.Sequence(
                        datasets.Value("string")
                    ),
                }
            ),
            homepage="https://github.com/StonyBrookNLP/musique",
            codebase_urls=["https://github.com/StonyBrookNLP/musique"],
            reference_urls=["https://github.com/StonyBrookNLP/musique"],
        )

    def _download_and_prepare(self, dl_manager):
        """No external resources are needed to compute the scores."""
        pass

    def _compute(self, predictions, references):
        """Return the average Exact Match and F1 over all examples.

        Args:
            predictions: One predicted answer string per example.
            references: For each example, a list of acceptable gold answers.

        Returns:
            A dict with keys ``"exact_match"`` and ``"f1"``. Each value is the
            mean, over examples, of the best score against that example's gold
            answers. Both are 0.0 when there are no examples.

        Raises:
            ValueError: If ``predictions`` and ``references`` differ in length,
                or (via ``metric_max_over_ground_truths``) if an example has no
                gold answers.
        """
        if len(predictions) != len(references):
            raise ValueError(
                "The number of predictions and references should be the same."
            )
        if len(predictions) == 0:
            # Nothing to score; avoid dividing by zero below.
            return {"exact_match": 0.0, "f1": 0.0}
        exact_scores = [
            metric_max_over_ground_truths(compute_exact, prediction, reference)
            for prediction, reference in zip(predictions, references)
        ]
        f1_scores = [
            metric_max_over_ground_truths(compute_f1, prediction, reference)
            for prediction, reference in zip(predictions, references)
        ]
        return {
            "exact_match": sum(exact_scores) / len(exact_scores),
            "f1": sum(f1_scores) / len(f1_scores),
        }
# Source: https://github.com/StonyBrookNLP/musique/blob/main/metrics/answer.py
def normalize_answer(s):
    """Canonicalize an answer string so equivalent answers compare equal.

    The steps, in order:
      1. lowercase;
      2. delete every character in ``string.punctuation``;
      3. replace the standalone words "a", "an" and "the" with a space;
      4. collapse whitespace runs to single spaces and trim both ends.
    """
    text = s.lower()
    # One C-level pass that deletes all ASCII punctuation characters.
    text = text.translate(str.maketrans("", "", string.punctuation))
    text = re.sub(r"\b(a|an|the)\b", " ", text, flags=re.UNICODE)
    return " ".join(text.split())
def get_tokens(s):
    """Return the whitespace tokens of the normalized answer; falsy input gives []."""
    return normalize_answer(s).split() if s else []
def compute_exact(a_gold, a_pred):
    """Return 1 when both answers normalize to the same string, otherwise 0."""
    gold_norm = normalize_answer(a_gold)
    pred_norm = normalize_answer(a_pred)
    return 1 if gold_norm == pred_norm else 0
def compute_f1(a_gold, a_pred):
    """Token-level F1 between two answers after normalization.

    - If either answer has no tokens, return 1 when both are empty, else 0.
    - If the answers share no tokens, return 0.
    - Otherwise return the harmonic mean of precision (overlap / predicted
      tokens) and recall (overlap / gold tokens). The overlap is the size of
      the multiset intersection of the two token lists.
    """
    gold_tokens = get_tokens(a_gold)
    pred_tokens = get_tokens(a_pred)
    if not gold_tokens or not pred_tokens:
        # No-answer case: only full agreement (both empty) scores.
        return int(gold_tokens == pred_tokens)
    shared = collections.Counter(gold_tokens) & collections.Counter(pred_tokens)
    overlap = sum(shared.values())
    if overlap == 0:
        return 0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(gold_tokens)
    return 2 * precision * recall / (precision + recall)
def metric_max_over_ground_truths(
    metric_fn: Callable[[str, str], float],
    prediction: str,
    ground_truths: list[str],
) -> float:
    """Score ``prediction`` against each gold answer and return the best score.

    Args:
        metric_fn: Scoring function, called as ``metric_fn(prediction, ground_truth)``.
        prediction: The predicted answer.
        ground_truths: Acceptable gold answers for this prediction.

    Returns:
        The maximum value of ``metric_fn`` over ``ground_truths``.

    Raises:
        ValueError: If ``ground_truths`` is empty, since there is nothing to
            score the prediction against.
    """
    scores_for_ground_truths = [
        metric_fn(prediction, ground_truth) for ground_truth in ground_truths
    ]
    if not scores_for_ground_truths:
        # max([]) would also raise ValueError, but with a message that does not
        # point at the real problem (a reference with no gold answers).
        raise ValueError(
            "Each prediction needs at least one ground truth answer; "
            f"got none for prediction {prediction!r}."
        )
    return max(scores_for_ground_truths)