table_markdown / metric.py
maqiuping59's picture
Rename accuracy.py to metric.py
130cdb2 verified
raw
history blame
6.58 kB
import re
import json
import evaluate
import datasets
# Human-readable summary of the metric. It is prepended to the metric class
# docstring by `add_start_docstrings` and passed to `MetricInfo(description=...)`.
_DESCRIPTION = """
Table evaluation metrics for assessing the matching degree between predicted and reference tables. It calculates the following metrics:
1. Precision: The ratio of correctly predicted cells to the total number of cells in the predicted table
2. Recall: The ratio of correctly predicted cells to the total number of cells in the reference table
3. F1 Score: The harmonic mean of precision and recall
These metrics help evaluate the accuracy of table data extraction or generation.
"""
# Argument/return documentation plus a usage example. It is prepended to the
# metric class docstring and passed to `MetricInfo(inputs_description=...)`.
# NOTE(review): the example loads the metric as "accuracy" and passes plain
# strings, while `main()` in this file passes one-element lists -- presumably
# left over from the accuracy-metric template; confirm the intended load name
# and input type before publishing.
_KWARGS_DESCRIPTION = """
Args:
predictions (`str`): Predicted table in Markdown format.
references (`str`): Reference table in Markdown format.
Returns:
dict: A dictionary containing the following metrics:
- precision (`float`): Precision score, range [0,1]
- recall (`float`): Recall score, range [0,1]
- f1 (`float`): F1 score, range [0,1]
- true_positives (`int`): Number of correctly predicted cells
- false_positives (`int`): Number of incorrectly predicted cells
- false_negatives (`int`): Number of cells that were not predicted
Examples:
>>> accuracy_metric = evaluate.load("accuracy")
>>> results = accuracy_metric.compute(
... predictions="| | lobby | search | band | charge | chain ||--|--|--|--|--|--|| desire | 5 | 8 | 7 | 5 | 9 || wage | 1 | 5 | 3 | 8 | 5 |",
... references="| | lobby | search | band | charge | chain ||--|--|--|--|--|--|| desire | 1 | 6 | 7 | 5 | 9 || wage | 1 | 5 | 2 | 8 | 5 |"
... )
>>> print(results)
{'precision': 0.7, 'recall': 0.7, 'f1': 0.7, 'true_positives': 7, 'false_positives': 3, 'false_negatives': 3}
"""
# BibTeX entry passed to `MetricInfo(citation=...)`.
# NOTE(review): this cites scikit-learn, which this file does not import or
# call -- looks inherited from the accuracy-metric template; confirm whether
# it is the citation you want attached to this table metric.
_CITATION = """
@article{scikit-learn,
title={Scikit-learn: Machine Learning in {P}ython},
author={Pedregosa, F. and Varoquaux, G. and Gramfort, A. and Michel, V.
and Thirion, B. and Grisel, O. and Blondel, M. and Prettenhofer, P.
and Weiss, R. and Dubourg, V. and Vanderplas, J. and Passos, A. and
Cournapeau, D. and Brucher, M. and Perrot, M. and Duchesnay, E.},
journal={Journal of Machine Learning Research},
volume={12},
pages={2825--2830},
year={2011}
}
"""
@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class Accuracy(evaluate.Metric):
    """Cell-level precision, recall and F1 between two Markdown tables.

    Both inputs are reduced to dictionaries (``{column: value}`` for a table
    with one unlabelled data row, ``{row_label: {column: value}}`` otherwise)
    and compared cell by cell. A cell is a true positive when the same key
    path holds an equal value in both tables.
    """

    def _info(self):
        """Return the metric metadata and the expected input features."""
        return evaluate.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(
                {
                    "predictions": datasets.Value("string"),
                    "references": datasets.Value("string"),
                }
            ),
            reference_urls=["https://scikit-learn.org/stable/modules/generated/sklearn.metrics.accuracy_score.html"],
        )

    def _extract_markdown_table(self, text):
        """Return the table rows found in ``text`` as one compact string.

        All whitespace is removed first, so rows that were on separate lines
        end up separated by ``||``. Returns ``None`` when no row with at
        least two cells is found.
        """
        # Drop every whitespace character, not only "\n" and " ": a stray
        # "\r" or "\t" between rows would otherwise be matched as a cell and
        # glue neighbouring rows together.
        compact = "".join(text.split())
        rows = re.findall(r'\|(?:[^|]+\|)+[^|]+\|', compact)
        return "".join(rows) if rows else None

    @staticmethod
    def _parse_cell(cell):
        """Return ``cell`` as a float, or unchanged when it is not numeric."""
        try:
            return float(cell)
        except ValueError:
            # Keep textual cells so they can still be compared for equality
            # instead of aborting the whole evaluation.
            return cell

    @staticmethod
    def _is_separator_row(cells):
        """Return True when every cell is a delimiter such as ``-`` or ``:-:``."""
        return all(re.fullmatch(r':?-+:?', cell) for cell in cells)

    def _cells_to_dict(self, legends, cells):
        """Map ``cells`` to their column names.

        Cells beyond the end of ``legends`` are keyed by their position so
        they are still counted instead of raising ``IndexError``.
        """
        row = {}
        for position, cell in enumerate(cells):
            key = legends[position] if position < len(legends) else position
            row[key] = self._parse_cell(cell)
        return row

    def _table_to_dict(self, table_str):
        """Convert a compact table string into a (possibly nested) dict.

        Args:
            table_str: Rows separated by ``||`` and cells separated by ``|``,
                as produced by ``_extract_markdown_table``.

        Returns:
            ``{column: value}`` when there is a single data row that is not
            longer than the header, ``{row_label: {column: value}}``
            otherwise, or ``None`` when fewer than two usable rows remain.
        """
        rows = []
        for raw_row in table_str.strip("|").split("||"):
            if "--" in raw_row:
                continue
            cells = raw_row.split("|")
            # Also skip delimiter rows without a double dash (e.g. "|:-:|").
            if self._is_separator_row(cells):
                continue
            rows.append(cells)
        if len(rows) < 2:
            return None

        header, *body = rows
        first = body[0]
        if len(body) == 1 and len(first) <= len(header):
            return self._cells_to_dict(header, first)

        # When every data row is exactly as wide as the header, the first
        # header cell is a named corner cell sitting above the row labels;
        # dropping it keeps the column names aligned with the values.
        if all(len(cells) == len(header) for cells in body):
            header = header[1:]
        return {
            cells[0]: self._cells_to_dict(header, cells[1:])
            for cells in body
        }

    def _markdown_to_dict(self, markdown_str):
        """Parse Markdown text into a table dict, or ``None`` if none is found."""
        table_str = self._extract_markdown_table(markdown_str)
        if table_str:
            return self._table_to_dict(table_str)
        return None

    def _count_cells(self, value):
        """Return the number of leaf cells in ``value`` (1 for a scalar)."""
        if isinstance(value, dict):
            return sum(self._count_cells(item) for item in value.values())
        return 1

    def _calculate_table_metrics(self, pred_table, true_table):
        """Compare two table dicts cell by cell.

        Args:
            pred_table: Predicted table dict, or ``None`` when parsing failed.
            true_table: Reference table dict, or ``None`` when parsing failed.

        Returns:
            Dict with ``precision``, ``recall``, ``f1``, ``true_positives``,
            ``false_positives`` and ``false_negatives``.
        """
        # An unparseable table contributes no cells instead of crashing.
        pred_table = pred_table or {}
        true_table = true_table or {}

        true_positives = 0
        false_positives = 0
        false_negatives = 0

        for key, pred_value in pred_table.items():
            if key not in true_table:
                # A whole unmatched row counts once per cell it contains.
                false_positives += self._count_cells(pred_value)
                continue
            true_value = true_table[key]
            if isinstance(pred_value, dict) and isinstance(true_value, dict):
                nested = self._calculate_table_metrics(pred_value, true_value)
                true_positives += nested['true_positives']
                false_positives += nested['false_positives']
                false_negatives += nested['false_negatives']
            elif pred_value == true_value:
                true_positives += 1
            else:
                false_positives += self._count_cells(pred_value)
                false_negatives += self._count_cells(true_value)

        false_negatives += sum(
            self._count_cells(true_value)
            for key, true_value in true_table.items()
            if key not in pred_table
        )

        predicted_total = true_positives + false_positives
        reference_total = true_positives + false_negatives
        precision = true_positives / predicted_total if predicted_total > 0 else 0.0
        recall = true_positives / reference_total if reference_total > 0 else 0.0
        f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0
        return {
            'precision': precision,
            'recall': recall,
            'f1': f1,
            'true_positives': true_positives,
            'false_positives': false_positives,
            'false_negatives': false_negatives
        }

    def _compute(self, predictions, references):
        """Score the predicted table text against the reference table text."""
        # NOTE(review): all items are concatenated and parsed as ONE table, so
        # a batch holding several complete tables is not scored per sample --
        # confirm whether multi-sample batches need to be supported.
        predictions = "".join(predictions)
        references = "".join(references)
        return self._calculate_table_metrics(
            self._markdown_to_dict(predictions),
            self._markdown_to_dict(references),
        )
def main():
    """Run the metric on one sample pair of tables and print the result."""
    predicted_table = """
| | lobby | search | band | charge | chain ||--|--|--|--|--|--|| desire | 5 | 8 | 7 | 5 | 9 || wage | 1 | 5 | 3 | 8 | 5 |
"""
    reference_table = """
| | lobby | search | band | charge | chain ||--|--|--|--|--|--|| desire | 1 | 6 | 7 | 5 | 9 || wage | 1 | 5 | 2 | 8 | 5 |
"""
    metric = Accuracy()
    scores = metric.compute(
        predictions=[predicted_table],
        references=[reference_table],
    )
    print(scores)


if __name__ == "__main__":
    main()