Spaces:
Sleeping
Sleeping
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
"""TODO: Add a description here.""" | |
import evaluate | |
import datasets | |
# TODO: Add BibTeX citation
_CITATION = """\
@InProceedings{huggingface:module,
title = {A great new module},
authors={huggingface, Inc.},
year={2020}
}
"""
# Description shown on the module's page.
_DESCRIPTION = """\
Evaluate structured output formatting for generated text.
- considers header / column / tag / key names
- DOES NOT consider the cell / row values specifically
Currently only pipe-delimited Markdown tables are parsed: the header row and
the separator row are compared verbatim, and data rows are compared by their
cell counts.
Formats:
- [] Custom
- [x] Markdown tables
- [] HTML tables
- [] JSON
- [] XML
- [] CSV / TSV
"""
# Description of the arguments accepted by `compute` and of its outputs.
# The `\\n` below renders as a literal `\n` inside the example, so the
# doctest line stays on one line.
_KWARGS_DESCRIPTION = """
Calculates how well the `structure` of the predictions matches the `structure` of the references.
Args:
    predictions: list of strings to score. Each prediction should be a
        pipe-delimited Markdown table: a header row, a separator row, then
        one line per data row.
    references: list of reference tables, one per prediction, in the same
        format.
    num_proc: optional int, number of processes used to parse the tables.
Returns:
    kaushiks_criteria: mean of the three exact-match rates below.
    exact_match_headers: exact-match rate of the header rows.
    exact_match_sep_row: exact-match rate of the separator rows.
    exact_match_row_counts: exact-match rate of the data-row structure
        (number of data rows and number of cells in each row).
Examples:
    >>> kaushiks_criteria = evaluate.load("DoctorSlimm/kaushiks_criteria")
    >>> table = "| a | b |\\n| --- | --- |\\n| 1 | 2 |"
    >>> results = kaushiks_criteria.compute(predictions=[table], references=[table])
    >>> print(results["kaushiks_criteria"])
    1.0
"""
# TODO: Define external resources urls if needed
# Template placeholder; not referenced by the metric code in this file.
BAD_WORDS_URL = "http://url/to/external/resource/bad_words.txt"
class kaushiks_criteria(evaluate.Metric):
    """Score how well generated Markdown tables match the structure of references.

    Each prediction/reference string is reduced to three structural features
    (see ``normalize_fn``):

    * the header row (first line),
    * the separator row (second line),
    * the number of cells in every remaining (data) row.

    Each feature is compared across the corpus with the ``exact_match`` metric.
    The final ``kaushiks_criteria`` score is the unweighted mean of the three
    rates. Cell *values* are never compared.
    """

    def _info(self):
        """Describe the module: documentation strings and input features."""
        return evaluate.MetricInfo(
            # This is the description that will appear on the module's page.
            module_type="metric",
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            # One string per prediction and one string per reference.
            features=datasets.Features({
                'predictions': datasets.Value('string'),
                'references': datasets.Value('string'),
            }),
            # Homepage of the module for documentation (placeholder).
            homepage="http://module.homepage",
            # Additional links to the codebase or references (placeholders).
            codebase_urls=["http://github.com/path/to/codebase/of/new_module"],
            reference_urls=["http://path.to.reference.url/new_module"]
        )

    def _get_exact_match(self):
        """Return the ``exact_match`` metric, loading and caching it on first use."""
        metric = getattr(self, '_exact_match', None)
        if metric is None:
            metric = evaluate.load('exact_match')
            self._exact_match = metric
        return metric

    def _download_and_prepare(self, dl_manager):
        """Fetch the ``exact_match`` metric up front so ``_compute`` can reuse it."""
        self._get_exact_match()

    def normalize_fn(self, example, text_field='text'):
        """Reduce a Markdown table string to the structural features that get compared.

        Expected input format (special tokens are not handled)::

            | col1 | col2 | col3 |   <- leading and trailing pipes required
            | ---- | ---- | ---- |   <- separator row, compared verbatim
            | val1 | val2 | val3 |
            | ...  | ...  | ...  |

        Incomplete tables are accepted: missing rows yield empty features.

        :param example: mapping (e.g. a ``datasets`` row) holding the text.
        :param text_field: key of the text inside ``example``.
        :return: dict with
            ``headers``    - first line, stripped;
            ``sep_row``    - second line, stripped ('' if absent);
            ``row_counts`` - comma-separated cell counts of each data row
            (third line onwards), e.g. ``'3,3,2'``; '' if there are none.
        """
        rows = example[text_field].strip().split('\n')
        # str.split always returns at least one element, so rows[0] exists.
        headers = rows[0].strip()
        sep_row = rows[1].strip() if len(rows) > 1 else ''
        # Strip whitespace (including a '\r' from CRLF input) before the edge
        # pipes. Otherwise trailing spaces keep the final '|' and an extra
        # empty cell is counted.
        row_counts = [
            str(len(row.strip().strip('|').split('|')))
            for row in rows[2:]
        ]
        # The separator keeps the encoding unambiguous: without it, one row
        # of 12 cells and two rows of 1 and 2 cells would both become '12'.
        return {
            'headers': headers,
            'sep_row': sep_row,
            'row_counts': ','.join(row_counts),
        }

    def _compute(self, predictions, references, num_proc=None):
        """Compare the table structure of each prediction with its reference.

        Criteria, each scored by exact match over all pairs:

        * header row matches (column names and column order),
        * separator row matches,
        * total data-row count and number of cells in each row match.

        :param predictions: list of generated table strings.
        :param references: list of reference table strings, aligned with
            ``predictions``.
        :param num_proc: optional number of processes for ``Dataset.map``
            while parsing the strings.
        :return: dict with ``kaushiks_criteria`` (mean of the three rates)
            followed by ``exact_match_headers``, ``exact_match_sep_row`` and
            ``exact_match_row_counts``.
        :raises ValueError: if there are no prediction/reference pairs.
        """
        from datasets import Dataset, DatasetDict
        # Dataset.map skips the function on an empty dataset, so the
        # normalized columns would be missing and the lookups below would
        # fail with an opaque KeyError.
        if len(predictions) == 0:
            raise ValueError(
                "kaushiks_criteria requires at least one prediction/reference pair"
            )
        proc_ds = DatasetDict({
            'predictions': Dataset.from_dict({'text': predictions}),
            'references': Dataset.from_dict({'text': references}),
        }).map(
            self.normalize_fn,
            num_proc=num_proc,
            load_from_cache_file=False,
        )
        exact_match = self._get_exact_match()
        scores = {}
        for field in ('headers', 'sep_row', 'row_counts'):
            scores['exact_match_' + field] = exact_match.compute(
                predictions=proc_ds['predictions'][field],
                references=proc_ds['references'][field],
            )['exact_match']
        # Unweighted mean of the three structural criteria.
        score = sum(scores.values()) / 3.0
        return {'kaushiks_criteria': score, **scores}