import re

import datasets
import evaluate
_DESCRIPTION = """
Table evaluation metrics for assessing how closely a predicted table matches a reference table. The following metrics are calculated:
1. Precision: the ratio of correctly predicted cells to the total number of cells in the predicted table
2. Recall: the ratio of correctly predicted cells to the total number of cells in the reference table
3. F1 score: the harmonic mean of precision and recall
These metrics help evaluate the accuracy of table data extraction or generation.
"""
_KWARGS_DESCRIPTION = """
Args:
    predictions (`list` of `str`): Predicted table(s) in Markdown format.
    references (`list` of `str`): Reference table(s) in Markdown format.
Returns:
    dict: A dictionary containing the following metrics:
        - precision (`float`): Precision score, range [0, 1]
        - recall (`float`): Recall score, range [0, 1]
        - f1 (`float`): F1 score, range [0, 1]
        - true_positives (`int`): Number of correctly predicted cells
        - false_positives (`int`): Number of predicted cells that do not match the reference
        - false_negatives (`int`): Number of reference cells that were not predicted
Examples:
    >>> accuracy_metric = evaluate.load("accuracy")
    >>> results = accuracy_metric.compute(
    ...     predictions=["| | lobby | search | band | charge | chain ||--|--|--|--|--|--|| desire | 5 | 8 | 7 | 5 | 9 || wage | 1 | 5 | 3 | 8 | 5 |"],
    ...     references=["| | lobby | search | band | charge | chain ||--|--|--|--|--|--|| desire | 1 | 6 | 7 | 5 | 9 || wage | 1 | 5 | 2 | 8 | 5 |"]
    ... )
    >>> print(results)
    {'precision': 0.7, 'recall': 0.7, 'f1': 0.7, 'true_positives': 7, 'false_positives': 3, 'false_negatives': 3}

    Here 7 of the 10 predicted cells match the reference within the 5% relative tolerance, so precision = recall = 7/10 = 0.7.
"""
_CITATION = """
"""
class Accuracy(evaluate.Metric):
    def _info(self):
        return evaluate.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(
                {
                    "predictions": datasets.Value("string"),
                    "references": datasets.Value("string"),
                }
            ),
            reference_urls=["https://scikit-learn.org/stable/modules/generated/sklearn.metrics.accuracy_score.html"],
        )
    def _extract_markdown_table(self, text):
        """Normalize a Markdown table into a single pipe-delimited string, with rows separated by '||'."""
        text = text.replace('\n', '')
        text = text.replace(" ", "")
        # Tables whose header row starts with an empty corner cell begin with "||";
        # give that corner cell a placeholder name so the row splits cleanly.
        if text.startswith("||"):
            text = "|header" + text[1:]
        pattern = r'\|(?:[^|]+\|)+[^|]+\|'
        matches = re.findall(pattern, text)
        if matches:
            return ''.join(matches)
        return None
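    # Illustrative example (not part of the metric): after normalization,
    # "| check | lap |\n|--|--|\n| 80 | 40 |" becomes "|check|lap||--|--||80|40|",
    # which `_table_parse` below splits back into rows on "||".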
    def _table_parse(self, table_str):
        """Parse the normalized table string into [header_set, value] pairs for every numeric cell."""
        rows = table_str.split("||")
        rows = [item.strip("|").split("|") for item in rows]
        parse_result = []
        for i in range(len(rows)):
            pre_row = rows[i]
            for j in range(len(pre_row)):
                index = []
                try:
                    value = float(pre_row[j])
                    # Use the row label (first column) as a key if it is not numeric.
                    try:
                        float(pre_row[0])
                    except ValueError:
                        index.append(pre_row[0])
                    # Use the column header (first row) as a key if it is not numeric.
                    try:
                        float(rows[0][j])
                    except ValueError:
                        index.append(rows[0][j])
                    if len(index) > 0:
                        parse_result.append([set(index), value])
                except (ValueError, IndexError):
                    # Non-numeric cell (header or label) or ragged row; skip it.
                    continue
        return parse_result
    def _markdown_to_table(self, markdown_str):
        table_str = self._extract_markdown_table(markdown_str)
        if table_str:
            parse_result = self._table_parse(table_str)
            return parse_result
        return None
    def _calculate_table_metrics(self, pred_table, true_table):
        true_positives = 0
        false_positives = 0
        false_negatives = 0
        # Convert lists to dictionaries keyed by the (sorted) header tuple for easier comparison
        pred_dict = {tuple(sorted(item[0])): item[1] for item in pred_table}
        true_dict = {tuple(sorted(item[0])): item[1] for item in true_table}
        # Compare predictions with true values; a cell counts as correct if it is within
        # 5% relative tolerance of the reference value (absolute 0.05 when the reference is 0)
        for key, pred_value in pred_dict.items():
            if key in true_dict:
                true_value = true_dict[key]
                if true_value == 0 and abs(pred_value) < 0.05:
                    true_positives += 1
                elif true_value != 0 and abs((pred_value - true_value) / true_value) < 0.05:
                    true_positives += 1
                else:
                    false_positives += 1
                    false_negatives += 1
            else:
                false_positives += 1
        # Count false negatives (cells in the reference table that were never predicted)
        for key in true_dict:
            if key not in pred_dict:
                false_negatives += 1
        precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0
        recall = true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) > 0 else 0
        f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
        return {
            'precision': precision,
            'recall': recall,
            'f1': f1,
            'true_positives': true_positives,
            'false_positives': false_positives,
            'false_negatives': false_negatives
        }
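    # Worked example of the matching rule above (illustrative): for a reference cell of 6
    # and a predicted cell of 8, the relative error is |8 - 6| / 6 ≈ 0.33 > 0.05, so the
    # cell counts as one false positive and one false negative; a prediction of 6.2 would
    # give |6.2 - 6| / 6 ≈ 0.033 < 0.05 and count as a true positive.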
    def _compute(self, predictions, references):
        predictions = "".join(predictions)
        references = "".join(references)
        return self._calculate_table_metrics(self._markdown_to_table(predictions), self._markdown_to_table(references))
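
# A minimal, illustrative sketch (not part of the metric) of the internal cell
# representation the class builds: each numeric cell becomes [header_set, value],
# so cells are matched by their row/column headers rather than by position.
def _demo_internal_representation():
    metric = Accuracy()
    table = "| | lobby | search |\n|--|--|--|\n| desire | 5 | 8 |"
    # Assuming the parsing above, this prints (set ordering may vary):
    # [[{'desire', 'lobby'}, 5.0], [{'desire', 'search'}, 8.0]]
    print(metric._markdown_to_table(table))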
def main():
    accuracy_metric = Accuracy()
    # Compute the metrics
    results = accuracy_metric.compute(
        predictions=["""
        ||view|denial|finger|check|
        |--|--|--|--|--|
        |tour|9|8|4|7|
        |wall|4|9|5|7|
        |sex|2|6|3|1|
        """],  # predicted table
        references=["""
        |check|lap|
        |--|--|
        |80|40|
        """],  # reference table
    )
    print(results)  # print the results

if __name__ == '__main__':
    main()
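
# Note (illustrative): with the example inputs above, the predicted and reference tables
# share no (row header, column header) cell, so, assuming the parsing sketched earlier,
# all 12 predicted cells are false positives, both reference cells are false negatives,
# and precision, recall and f1 all come out as 0.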