wiki_split / wiki_split.py
lvwerra's picture
lvwerra HF staff
Update Space (evaluate main: dfdd0cc0)
b095d0a
raw
history blame
14.7 kB
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" WIKI_SPLIT metric."""
import re
import string
from collections import Counter
import datasets
import sacrebleu
import sacremoses
from packaging import version
import evaluate
# BibTeX entries for the two papers cited by this metric. The text is handed
# verbatim to EvaluationModuleInfo(citation=...), so every entry has to be
# brace-balanced and use standard field names (`author`, not `authors`).
_CITATION = """
@inproceedings{xu-etal-2016-optimizing,
title = {Optimizing Statistical Machine Translation for Text Simplification},
author = {Xu, Wei and Napoles, Courtney and Pavlick, Ellie and Chen, Quanze and Callison-Burch, Chris},
journal = {Transactions of the Association for Computational Linguistics},
volume = {4},
year={2016},
url = {https://www.aclweb.org/anthology/Q16-1029},
pages = {401--415},
}
@inproceedings{post-2018-call,
title = "A Call for Clarity in Reporting {BLEU} Scores",
author = "Post, Matt",
booktitle = "Proceedings of the Third Conference on Machine Translation: Research Papers",
month = oct,
year = "2018",
address = "Belgium, Brussels",
publisher = "Association for Computational Linguistics",
url = "https://www.aclweb.org/anthology/W18-6319",
pages = "186--191",
}
"""
# One-paragraph summary of the metric. Used as `description` in
# WikiSplit._info() and passed to `add_start_docstrings` on the WikiSplit class.
_DESCRIPTION = """\
WIKI_SPLIT is the combination of three metrics SARI, EXACT and SACREBLEU
It can be used to evaluate the quality of machine-generated texts.
"""
# Argument/return documentation plus a usage example. Used as
# `inputs_description` in WikiSplit._info() and passed to `add_start_docstrings`
# on the WikiSplit class.
# NOTE(review): the example is written in doctest form, so the printed values
# presumably must match the real output exactly — confirm before editing them.
_KWARGS_DESCRIPTION = """
Calculates sari score (between 0 and 100) given a list of source and predicted
sentences, and a list of lists of reference sentences. It also computes the BLEU score as well as the exact match score.
Args:
sources: list of source sentences where each sentence should be a string.
predictions: list of predicted sentences where each sentence should be a string.
references: list of lists of reference sentences where each sentence should be a string.
Returns:
sari: sari score
sacrebleu: sacrebleu score
exact: exact score
Examples:
>>> sources=["About 95 species are currently accepted ."]
>>> predictions=["About 95 you now get in ."]
>>> references=[["About 95 species are currently known ."]]
>>> wiki_split = evaluate.load("wiki_split")
>>> results = wiki_split.compute(sources=sources, predictions=predictions, references=references)
>>> print(results)
{'sari': 21.805555555555557, 'sacrebleu': 14.535768424205482, 'exact': 0.0}
"""
def normalize_answer(s):
    """Return a canonical form of ``s`` for exact-match comparison.

    The steps run in this order: lowercase the text, drop every character
    found in ``string.punctuation``, replace the standalone words "a", "an"
    and "the" with a space, then collapse every whitespace run into a single
    space (which also trims both ends).
    """
    lowered = s.lower()

    punctuation = set(string.punctuation)
    without_punct = "".join(ch for ch in lowered if ch not in punctuation)

    # Word boundaries keep words such as "theme" or "and" intact.
    without_articles = re.sub(r"\b(a|an|the)\b", " ", without_punct, flags=re.UNICODE)

    return " ".join(without_articles.split())
def compute_exact(a_gold, a_pred):
    """Return 1 if both strings are equal after ``normalize_answer``, else 0."""
    gold_norm = normalize_answer(a_gold)
    pred_norm = normalize_answer(a_pred)
    return 1 if gold_norm == pred_norm else 0
def compute_em(predictions, references):
    """Return the percentage of predictions that exactly match a reference.

    Args:
        predictions: predicted sentences.
        references: for each prediction, the collection of its references.

    Returns:
        ``100 * matched / total``, where a prediction counts as matched when
        ``compute_exact`` returns 1 for at least one of its references.
    """
    matched = []
    for pred, refs in zip(predictions, references):
        per_reference = [compute_exact(ref, pred) for ref in refs]
        matched.append(any(per_reference))
    return (sum(matched) / len(matched)) * 100
def SARIngram(sgrams, cgrams, rgramslist, numref):
    """Score one n-gram order of a candidate against its source and references.

    Args:
        sgrams: n-grams of the source sentence.
        cgrams: n-grams of the candidate (predicted) sentence.
        rgramslist: one list of n-grams per reference sentence.
        numref: factor by which the source and candidate counts are multiplied
            before they are compared with the pooled reference counts.

    Returns:
        A tuple ``(keepscore, delscore_precision, addscore)``: the F1 of the
        keep operation, the precision of the delete operation and the F1 of
        the add operation. A precision or recall whose denominator is empty
        defaults to 1.
    """
    ref_counts = Counter(rgram for rgrams in rgramslist for rgram in rgrams)
    src_counts = Counter(sgrams)
    cand_counts = Counter(cgrams)

    # Reference counts are pooled over all references, so source and candidate
    # counts are scaled by ``numref`` before the multiset operations below.
    src_rep = Counter({gram: count * numref for gram, count in src_counts.items()})
    cand_rep = Counter({gram: count * numref for gram, count in cand_counts.items()})

    # KEEP: n-grams present in both source and candidate.
    keep_rep = src_rep & cand_rep
    keep_good_rep = keep_rep & ref_counts
    keep_all_rep = src_rep & ref_counts
    keep_precision_sum = 0
    keep_recall_sum = 0
    for gram, good_count in keep_good_rep.items():
        keep_precision_sum += good_count / keep_rep[gram]
        # Recall accumulates raw counts and is later divided by the total
        # count; the per-n-gram ratio used previously is described in the
        # original comments as an alleged bug in the keep score computation.
        keep_recall_sum += good_count

    # Define 0/0=1 instead of 0 to give higher scores for predictions that
    # match a target exactly.
    keepscore_precision = 1
    keepscore_recall = 1
    if len(keep_rep) > 0:
        keepscore_precision = keep_precision_sum / len(keep_rep)
    if len(keep_all_rep) > 0:
        keepscore_recall = keep_recall_sum / sum(keep_all_rep.values())
    keepscore = 0
    if keepscore_precision > 0 or keepscore_recall > 0:
        keepscore = 2 * keepscore_precision * keepscore_recall / (keepscore_precision + keepscore_recall)

    # DELETION: n-grams of the source that the candidate dropped. Only the
    # precision is returned, so no recall-side quantities are computed.
    del_rep = src_rep - cand_rep
    del_good_rep = del_rep - ref_counts
    del_precision_sum = 0
    for gram, good_count in del_good_rep.items():
        del_precision_sum += good_count / del_rep[gram]

    # Same 0/0=1 convention as above.
    delscore_precision = 1
    if len(del_rep) > 0:
        delscore_precision = del_precision_sum / len(del_rep)

    # ADDITION: distinct n-grams (sets, not counts) that the candidate
    # introduced relative to the source.
    added = set(cand_counts) - set(src_counts)
    added_good = added & set(ref_counts)
    addable = set(ref_counts) - set(src_counts)
    added_good_total = len(added_good)

    # Same 0/0=1 convention as above.
    addscore_precision = 1
    addscore_recall = 1
    if len(added) > 0:
        addscore_precision = added_good_total / len(added)
    if len(addable) > 0:
        addscore_recall = added_good_total / len(addable)
    addscore = 0
    if addscore_precision > 0 or addscore_recall > 0:
        addscore = 2 * addscore_precision * addscore_recall / (addscore_precision + addscore_recall)

    return (keepscore, delscore_precision, addscore)
def _sari_word_ngrams(tokens, order):
    """Return the space-joined n-grams of length ``order`` in ``tokens``, left to right."""
    return [" ".join(tokens[start : start + order]) for start in range(len(tokens) - order + 1)]


def SARIsent(ssent, csent, rsents):
    """Compute the sentence-level SARI score for one candidate.

    Args:
        ssent: source sentence; tokens are obtained by splitting on a single space.
        csent: candidate (predicted) sentence, tokenized the same way.
        rsents: reference sentences, tokenized the same way.

    Returns:
        The mean of the keep, delete and add scores, each of which is itself
        the mean of the corresponding ``SARIngram`` component over n-gram
        orders 1 to 4.
    """
    numref = len(rsents)
    source_tokens = ssent.split(" ")
    candidate_tokens = csent.split(" ")
    reference_tokens = [rsent.split(" ") for rsent in rsents]

    keep_scores = []
    del_scores = []
    add_scores = []
    for order in (1, 2, 3, 4):
        keep_score, del_score, add_score = SARIngram(
            _sari_word_ngrams(source_tokens, order),
            _sari_word_ngrams(candidate_tokens, order),
            [_sari_word_ngrams(tokens, order) for tokens in reference_tokens],
            numref,
        )
        keep_scores.append(keep_score)
        del_scores.append(del_score)
        add_scores.append(add_score)

    # Always divide by 4, even when a sentence is too short to contain
    # n-grams of the higher orders.
    avgkeepscore = sum(keep_scores) / 4
    avgdelscore = sum(del_scores) / 4
    avgaddscore = sum(add_scores) / 4
    return (avgkeepscore + avgdelscore + avgaddscore) / 3
def normalize(sentence, lowercase: bool = True, tokenizer: str = "13a", return_str: bool = True):
    """Lowercase and tokenize ``sentence`` so it can be split on spaces.

    Args:
        sentence: the text to normalize.
        lowercase: when true, lowercase the text before tokenizing.
        tokenizer: "13a" or "intl" select a sacrebleu tokenizer, "moses" and
            "penn" select the sacremoses tokenizers; any other value leaves
            the text untokenized.
        return_str: when false, return the result of ``str.split()`` on the
            normalized text instead of the string itself.

    Returns:
        The normalized sentence, as a string or as a list of tokens.
    """
    # Normalization is required for the ASSET dataset (one of the primary
    # datasets in sentence simplification) to allow using space
    # to split the sentence. Even though Wiki-Auto and TURK datasets
    # do not require normalization, we do it for consistency.
    # Code adapted from the EASSE library [1] written by the authors of the ASSET dataset.
    # [1] https://github.com/feralvam/easse/blob/580bba7e1378fc8289c663f864e0487188fe8067/easse/utils/preprocessing.py#L7
    text = sentence.lower() if lowercase else sentence

    if tokenizer in ("13a", "intl"):
        # The tokenizer lookup differs between sacrebleu major versions.
        if version.parse(sacrebleu.__version__).major >= 2:
            tokenizer_cls = sacrebleu.metrics.bleu._get_tokenizer(tokenizer)
        else:
            tokenizer_cls = sacrebleu.TOKENIZERS[tokenizer]
        tokenized = tokenizer_cls()(text)
    elif tokenizer == "moses":
        tokenized = sacremoses.MosesTokenizer().tokenize(text, return_str=True, escape=False)
    elif tokenizer == "penn":
        tokenized = sacremoses.MosesTokenizer().penn_tokenize(text, return_str=True)
    else:
        tokenized = text

    return tokenized if return_str else tokenized.split()
def compute_sari(sources, predictions, references):
    """Return the corpus SARI score: 100 times the mean sentence-level score.

    Args:
        sources: source sentences.
        predictions: predicted sentences, one per source.
        references: for each prediction, its reference sentences.

    Raises:
        ValueError: if the three inputs do not all have the same length.
    """
    if len(sources) != len(predictions) or len(predictions) != len(references):
        raise ValueError("Sources length must match predictions and references lengths.")

    # Every sentence goes through normalize() with its defaults before scoring.
    total = 0
    for src, pred, refs in zip(sources, predictions, references):
        normalized_refs = [normalize(sent) for sent in refs]
        total += SARIsent(normalize(src), normalize(pred), normalized_refs)

    mean_score = total / len(predictions)
    return 100 * mean_score
def compute_sacrebleu(
    predictions,
    references,
    smooth_method="exp",
    smooth_value=None,
    force=False,
    lowercase=False,
    use_effective_order=False,
):
    """Return the corpus-level BLEU score computed by ``sacrebleu.corpus_bleu``.

    Args:
        predictions: predicted sentences.
        references: for each prediction, its reference sentences; every
            prediction must have the same number of references.
        smooth_method, smooth_value, force, lowercase, use_effective_order:
            forwarded unchanged to ``sacrebleu.corpus_bleu``.

    Returns:
        The ``score`` attribute of the sacrebleu result.

    Raises:
        ValueError: if the predictions do not all have the same number of
            references.
    """
    expected_count = len(references[0])
    for refs in references:
        if len(refs) != expected_count:
            raise ValueError("Sacrebleu requires the same number of references for each prediction")

    # Transpose from one reference list per prediction to one list per
    # reference position, each holding an entry for every prediction.
    reference_streams = [
        [refs[position] for refs in references] for position in range(expected_count)
    ]

    bleu = sacrebleu.corpus_bleu(
        predictions,
        reference_streams,
        smooth_method=smooth_method,
        smooth_value=smooth_value,
        force=force,
        lowercase=lowercase,
        use_effective_order=use_effective_order,
    )
    return bleu.score
# Evaluation module that reports SARI, SacreBLEU and exact-match together.
# NOTE(review): `sources` is an argument of `_compute` but is not declared in
# either feature schema returned by `_info` — confirm this is intended.
@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class WikiSplit(evaluate.EvaluationModule):
    def _info(self):
        """Return the module metadata: descriptions, citation, input schemas and URLs."""
        # Two accepted input layouts: several references per prediction, or
        # exactly one reference given as a plain string.
        multi_reference_features = datasets.Features(
            {
                "predictions": datasets.Value("string", id="sequence"),
                "references": datasets.Sequence(datasets.Value("string", id="sequence"), id="references"),
            }
        )
        single_reference_features = datasets.Features(
            {
                "predictions": datasets.Value("string", id="sequence"),
                "references": datasets.Value("string", id="sequence"),
            }
        )
        return evaluate.EvaluationModuleInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=[multi_reference_features, single_reference_features],
            codebase_urls=[
                "https://github.com/huggingface/transformers/blob/master/src/transformers/data/metrics/squad_metrics.py",
                "https://github.com/cocoxu/simplification/blob/master/SARI.py",
                "https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/sari_hook.py",
                "https://github.com/mjpost/sacreBLEU",
            ],
            reference_urls=[
                "https://www.aclweb.org/anthology/Q16-1029.pdf",
                "https://github.com/mjpost/sacreBLEU",
                "https://en.wikipedia.org/wiki/BLEU",
                "https://towardsdatascience.com/evaluating-text-output-in-nlp-bleu-at-your-own-risk-e8609665a213",
            ],
        )

    def _compute(self, sources, predictions, references):
        """Return a dict with the "sari", "sacrebleu" and "exact" scores."""
        # If only one reference is provided per prediction, wrap each one so
        # the scoring functions still receive a list of lists.
        if isinstance(references[0], str):
            references = [[ref] for ref in references]
        return {
            "sari": compute_sari(sources=sources, predictions=predictions, references=references),
            "sacrebleu": compute_sacrebleu(predictions=predictions, references=references),
            "exact": compute_em(predictions=predictions, references=references),
        }