"""nl2bash metric.""" |
|
import string

import datasets

import evaluate
|
_DESCRIPTION = """
Measures how closely a generated bash command matches a reference command,
as a weighted combination of command-name, option, and argument accuracy.
Higher is better; a perfect match scores the sum of the three weights
(1.05 with the default weights).
"""

_KWARGS_DESCRIPTION = """
Args:
    predictions: List of predicted bash commands.
    references: List of references. Each entry may be a single reference string
        or a list of candidate reference strings; the best-scoring candidate is used.
    cmd_weight: Weight placed on getting the command name correct. Defaults to 0.65.
    opt_weight: Weight placed on getting the options correct. Defaults to 0.25.
    arg_weight: Weight placed on getting the arguments correct. Defaults to 0.15.
    ignore_case: If True, compare case-insensitively. Defaults to True.
    ignore_numbers: If True, strip digits before comparing. Defaults to True.
Returns:
    nl2bash_m: Dictionary containing the nl2bash_m score. Values range from 0.0
        to the sum of the three weights (1.05 with the defaults).
Examples:

    >>> metric = evaluate.load("Josh98/nl2bash_m")
    >>> preds = ["ls -l /home/userr", "ls -l /home/josh", "lss /home/josh some argument"]
    >>> refs = [["ls -l /home/user"], ["ls -l --v /home/josh"], ["ls /home/josh"]]
    >>> results = metric.compute(references=refs, predictions=preds)
    >>> print(round(results["nl2bash_m"], 3))
    0.708
"""

_CITATION = """
"""

@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class nl2bash_m(evaluate.Metric):
    def _info(self):
        return evaluate.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=[
                datasets.Features(
                    {
                        "predictions": datasets.Value("string", id="sequence"),
                        "references": datasets.Sequence(datasets.Value("string", id="sequence"), id="references"),
                    }
                ),
                datasets.Features(
                    {
                        "predictions": datasets.Value("string", id="sequence"),
                        "references": datasets.Value("string", id="sequence"),
                    }
                ),
            ],
            reference_urls=[],
        )
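
    # Both schemas declared in `features` are accepted: `references` may hold a
    # list of candidate strings per prediction, or a single string per prediction.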

    def get_score(self, pred, ref):
        """Positional token overlap between two token lists, in [0.0, 1.0]."""
        # Two empty lists count as a perfect match.
        if not pred and not ref:
            return 1
        cor = 0
        for i in range(min(len(pred), len(ref))):
            if pred[i] == ref[i]:
                cor += 1
        return cor / max(len(pred), len(ref))
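
    # For example, get_score(["-l"], ["-l", "--v"]) returns 0.5: one of the two
    # reference options matches at its position.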

    def _compute(
        self,
        predictions,
        references,
        cmd_weight=0.65,
        opt_weight=0.25,
        arg_weight=0.15,
        ignore_case=True,
        ignore_numbers=True,
    ):
        # Normalize both declared schemas to one list of candidates per prediction.
        references = [[ref] if isinstance(ref, str) else list(ref) for ref in references]
        predictions = list(predictions)

        if ignore_case:
            predictions = [pred.lower() for pred in predictions]
            references = [[ref.lower() for ref in refs] for refs in references]

        if ignore_numbers:
            repl_table = str.maketrans("", "", string.digits)
            predictions = [pred.translate(repl_table) for pred in predictions]
            references = [[ref.translate(repl_table) for ref in refs] for refs in references]

        final_score = 0
        for pred, refs in zip(predictions, references):
            best_score = 0
            for ref in refs:
                pred_words, ref_words = pred.split(), ref.split()

                # An empty prediction matches only an empty reference.
                if not pred_words and not ref_words:
                    best_score = max(best_score, 1)
                    continue
                if not pred_words or not ref_words:
                    continue

                # The first token is the command name: full credit or none.
                cmd_corr = 1 if pred_words.pop(0) == ref_words.pop(0) else 0

                # Options start with '-'; every other token is an argument.
                pred_option = [x for x in pred_words if x[0] == "-"]
                ref_option = [x for x in ref_words if x[0] == "-"]
                pred_args = [x for x in pred_words if x[0] != "-"]
                ref_args = [x for x in ref_words if x[0] != "-"]

                # Weighted sum of the three components; note the default
                # weights sum to 1.05, so a perfect match scores 1.05.
                cmd_score = cmd_weight * cmd_corr
                opt_score = opt_weight * self.get_score(pred_option, ref_option)
                arg_score = arg_weight * self.get_score(pred_args, ref_args)

                best_score = max(best_score, cmd_score + opt_score + arg_score)

            final_score += best_score

        return {"nl2bash_m": final_score / len(predictions)}