"""Exact Match metric.""" |
|
import re |
|
import string |
|
|
|
import datasets |
|
import numpy as np |
|
|
|
import evaluate |
|
|
|
|
|
_DESCRIPTION = """
Returns a score indicating how closely a generated bash command matches the reference command.
The score is a weighted combination of matching the command name, its options, and its arguments;
higher is better.
"""
|
|
|
_KWARGS_DESCRIPTION = """
Args:
    predictions: List of predicted bash commands.
    references: List of reference bash commands.
    cmd_weight: Float, defaults to 0.65. Weight given to matching the command name
        (the first token of the command).
    opt_weight: Float, defaults to 0.25. Weight given to matching the options
        (tokens that start with "-").
    arg_weight: Float, defaults to 0.15. Weight given to matching the remaining
        arguments.
    ignore_case: Boolean, defaults to False. If true, turns everything
        to lowercase so that capitalization differences are ignored.
    ignore_numbers: Boolean, defaults to False. If true, removes all digits before
        comparing predictions and references.
Returns:
    nl2bash_m: Dictionary containing the nl2bash_m score averaged over all
        prediction/reference pairs. Higher is better; the maximum equals
        cmd_weight + opt_weight + arg_weight.
Examples:
    >>> nl2bash_m = evaluate.load("nl2bash_m")
    >>> refs = ["ls -l /home"]
    >>> preds = ["ls /home"]
    >>> results = nl2bash_m.compute(references=refs, predictions=preds)
    >>> print(round(results["nl2bash_m"], 2))
    0.8
    >>> results = nl2bash_m.compute(references=["LS -l /home"], predictions=["ls /home"], ignore_case=True)
    >>> print(round(results["nl2bash_m"], 2))
    0.8
"""

_CITATION = """
"""


@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class nl2bash_m(evaluate.Metric):
    def _info(self):
        return evaluate.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(
                {
                    "predictions": datasets.Value("string", id="sequence"),
                    "references": datasets.Value("string", id="sequence"),
                }
            ),
            reference_urls=[],
        )

    def get_score(self, pred, ref):
        """Fraction of positions where the predicted and reference tokens match exactly."""
        if not pred and not ref:
            return 1.0
        cor = 0
        for i in range(min(len(pred), len(ref))):
            if pred[i] == ref[i]:
                cor += 1
        return cor / max(len(pred), len(ref))
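
    # Illustrative values for get_score, derived from the logic above (not part of the metric):
    #   get_score(["-l"], ["-l"])          -> 1.0
    #   get_score(["-r"], ["-rf"])         -> 0.0   (tokens are compared as whole strings)
    #   get_score(["a", "b"], ["a", "c"])  -> 0.5
    #   get_score([], [])                  -> 1.0   (two empty token lists count as a match)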
|
|
|
    def _compute(
        self,
        predictions,
        references,
        cmd_weight=0.65,
        opt_weight=0.25,
        arg_weight=0.15,
        ignore_case=False,
        ignore_numbers=False,
    ):
        predictions = np.asarray(predictions)
        references = np.asarray(references)

        if ignore_case:
            predictions = np.char.lower(predictions)
            references = np.char.lower(references)

        if ignore_numbers:
            repl_table = string.digits.maketrans("", "", string.digits)
            predictions = np.char.translate(predictions, table=repl_table)
            references = np.char.translate(references, table=repl_table)

        final_score = 0

        for pred, ref in zip(predictions, references):
            pred_words, ref_words = pred.split(), ref.split()

            # The first token is the command name; guard against empty commands.
            pred_cmd = pred_words.pop(0) if pred_words else ""
            ref_cmd = ref_words.pop(0) if ref_words else ""
            cmd_corr = 1 if pred_cmd == ref_cmd else 0

            # Options are the tokens that start with "-"; everything else is an argument.
            pred_option = [x for x in pred_words if x[0] == "-"]
            ref_option = [x for x in ref_words if x[0] == "-"]

            pred_args = [x for x in pred_words if x[0] != "-"]
            ref_args = [x for x in ref_words if x[0] != "-"]

            cmd_score = cmd_weight * cmd_corr
            opt_score = opt_weight * self.get_score(pred_option, ref_option)
            arg_score = arg_weight * self.get_score(pred_args, ref_args)

            final_score += cmd_score + opt_score + arg_score

        final_score = final_score / len(predictions)

        return {"nl2bash_m": final_score}