# table_markdown/table_markdown.py
import re
import evaluate
import datasets
_DESCRIPTION = """
Table evaluation metrics for assessing how closely a predicted table matches a reference table, cell by cell. The following metrics are computed:
1. Precision: the ratio of correctly predicted cells to the total number of cells in the predicted table.
2. Recall: the ratio of correctly predicted cells to the total number of cells in the reference table.
3. F1 score: the harmonic mean of precision and recall.
These metrics help evaluate the accuracy of table data extraction or generation.
"""
_KWARGS_DESCRIPTION = """
Args:
    predictions (`list` of `str`): Predicted table(s) in Markdown format.
    references (`list` of `str`): Reference table(s) in Markdown format.
Returns:
    dict: A dictionary containing the following metrics:
        - precision (`float`): Precision score, range [0, 1]
        - recall (`float`): Recall score, range [0, 1]
        - f1 (`float`): F1 score, range [0, 1]
        - true_positives (`int`): Number of correctly predicted cells
        - false_positives (`int`): Number of incorrectly predicted cells
        - false_negatives (`int`): Number of reference cells that were not predicted
Examples:
    >>> table_metric = evaluate.load("table_markdown")
    >>> results = table_metric.compute(
    ...     predictions=["| | lobby | search | band | charge | chain ||--|--|--|--|--|--|| desire | 5 | 8 | 7 | 5 | 9 || wage | 1 | 5 | 3 | 8 | 5 |"],
    ...     references=["| | lobby | search | band | charge | chain ||--|--|--|--|--|--|| desire | 1 | 6 | 7 | 5 | 9 || wage | 1 | 5 | 2 | 8 | 5 |"]
    ... )
    >>> print(results)
    {'precision': 0.7, 'recall': 0.7, 'f1': 0.7, 'true_positives': 7, 'false_positives': 3, 'false_negatives': 3}
"""
_CITATION = """
@article{scikit-learn,
  title={Scikit-learn: Machine Learning in {P}ython},
  author={Pedregosa, F. and Varoquaux, G. and Gramfort, A. and Michel, V.
          and Thirion, B. and Grisel, O. and Blondel, M. and Prettenhofer, P.
          and Weiss, R. and Dubourg, V. and Vanderplas, J. and Passos, A. and
          Cournapeau, D. and Brucher, M. and Perrot, M. and Duchesnay, E.},
  journal={Journal of Machine Learning Research},
  volume={12},
  pages={2825--2830},
  year={2011}
}
"""
@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class TableMarkdown(evaluate.Metric):
    def _info(self):
        return evaluate.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(
                {
                    "predictions": datasets.Value("string"),
                    "references": datasets.Value("string"),
                }
            ),
        )
    def _extract_markdown_table(self, text):
        # Collapse the table into a single string; note that stripping every
        # space also removes spaces inside cell contents.
        text = text.replace('\n', '')
        text = text.replace(" ", "")
        # Match runs of pipe-delimited, non-empty cells.
        pattern = r'\|(?:[^|]+\|)+[^|]+\|'
        matches = re.findall(pattern, text)
        if matches:
            return ''.join(matches)
        return None
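    # Illustration (a hedged sketch, not part of the original logic): for a
    # small table with an empty leading header cell, the collapse and regex
    # pass transform the Markdown source like so:
    #
    #   raw = "| | lobby | search |\n|--|--|--|\n| desire | 5 | 8 |"
    #   # -> "|lobby|search||--|--|--||desire|5|8|"
    #
    # The empty header cell is dropped because the regex only matches
    # non-empty cells, and since every space is stripped, multi-word cell
    # text is concatenated ("New York" becomes "NewYork"); cells are
    # therefore assumed to contain no internal spaces.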
    def _table_to_dict(self, table_str):
        result_dict = {}
        table_str = table_str.lstrip("|").rstrip("|")
        # Rows are separated by "||" after the collapse in
        # _extract_markdown_table; drop the Markdown separator row ("--").
        parts = table_str.split('||')
        parts = [part for part in parts if "--" not in part]
        legends = parts[0].split("|")
        rows = len(parts)
        if rows == 2:
            # Header plus a single unlabeled value row: build a flat mapping
            # from column header to numeric value.
            nums = parts[1].split("|")
            for i in range(len(nums)):
                result_dict[legends[i]] = float(nums[i])
        elif rows >= 3:
            # Header plus labeled rows: build a nested mapping keyed by row
            # label, then by column header.
            for i in range(1, rows):
                pre_row = parts[i].split("|")
                label = pre_row[0]
                result_dict[label] = {}
                for j in range(1, len(pre_row)):
                    result_dict[label][legends[j - 1]] = float(pre_row[j])
        else:
            return None
        return result_dict
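    # Illustration (hedged): the two-data-row tables used in main() below
    # parse to nested dicts keyed by row label, then column header:
    #
    #   {"desire": {"lobby": 5.0, "search": 8.0, ...},
    #    "wage":   {"lobby": 1.0, "search": 5.0, ...}}
    #
    # Two assumptions follow from the code above: every value cell must
    # parse as a float, and a table with exactly one labeled data row falls
    # into the rows == 2 branch, where float() on the row label raises
    # ValueError.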
    def _markdown_to_dict(self, markdown_str):
        table_str = self._extract_markdown_table(markdown_str)
        if table_str:
            return self._table_to_dict(table_str)
        else:
            return None
    def _calculate_table_metrics(self, pred_table, true_table):
        true_positives = 0
        false_positives = 0
        false_negatives = 0
        # Walk every key/value pair in the predicted table.
        for key, pred_value in pred_table.items():
            if key in true_table:
                true_value = true_table[key]
                if isinstance(pred_value, dict) and isinstance(true_value, dict):
                    # Recurse into nested rows and accumulate their counts.
                    nested_metrics = self._calculate_table_metrics(pred_value, true_value)
                    true_positives += nested_metrics['true_positives']
                    false_positives += nested_metrics['false_positives']
                    false_negatives += nested_metrics['false_negatives']
                # The values match exactly.
                elif pred_value == true_value:
                    true_positives += 1
                else:
                    # A mismatched cell counts as both a false positive
                    # (wrong prediction) and a false negative (missed
                    # reference value).
                    false_positives += 1
                    false_negatives += 1
            else:
                false_positives += 1
        # Count reference entries that were never predicted.
        for key in true_table:
            if key not in pred_table:
                false_negatives += 1
        # Compute the metrics, guarding against division by zero.
        precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0
        recall = true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) > 0 else 0
        f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
        return {
            'precision': precision,
            'recall': recall,
            'f1': f1,
            'true_positives': true_positives,
            'false_positives': false_positives,
            'false_negatives': false_negatives
        }
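    # Worked example (derived from the tables in main() below): three cells
    # differ (desire/lobby, desire/search, wage/band) and seven match, so
    #   true_positives = 7, false_positives = 3, false_negatives = 3
    #   precision = 7 / (7 + 3) = 0.7
    #   recall    = 7 / (7 + 3) = 0.7
    #   f1        = 2 * (0.7 * 0.7) / (0.7 + 0.7) = 0.7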
    def _compute(self, predictions, references):
        predictions = "".join(predictions)
        references = "".join(references)
        pred_table = self._markdown_to_dict(predictions)
        true_table = self._markdown_to_dict(references)
        if pred_table is None or true_table is None:
            raise ValueError("Could not extract a Markdown table from the input.")
        return self._calculate_table_metrics(pred_table, true_table)
def main():
    table_metric = TableMarkdown()
    # Compute the metrics.
    results = table_metric.compute(
        predictions=["""
        | | lobby | search | band | charge | chain ||--|--|--|--|--|--|| desire | 5 | 8 | 7 | 5 | 9 || wage | 1 | 5 | 3 | 8 | 5 |
        """],  # the predicted table
        references=["""
        | | lobby | search | band | charge | chain ||--|--|--|--|--|--|| desire | 1 | 6 | 7 | 5 | 9 || wage | 1 | 5 | 2 | 8 | 5 |
        """],  # the reference table
    )
    print(results)  # print the results


if __name__ == '__main__':
    main()