# table_markdown / metric.py
import re
import evaluate
import datasets
_DESCRIPTION = """
Table evaluation metrics for assessing how well a predicted table matches a reference table. It computes the following metrics:
1. Precision: The ratio of correctly predicted cells to the total number of cells in the predicted table
2. Recall: The ratio of correctly predicted cells to the total number of cells in the reference table
3. F1 Score: The harmonic mean of precision and recall
A predicted numeric cell counts as correct when it matches the reference value within a 5% relative tolerance.
These metrics help evaluate the accuracy of table data extraction or generation.
"""
_KWARGS_DESCRIPTION = """
Args:
    predictions (`list` of `str`): Predicted table in Markdown format; list entries are concatenated before parsing.
    references (`list` of `str`): Reference table in Markdown format; list entries are concatenated before parsing.
Returns:
dict: A dictionary containing the following metrics:
- precision (`float`): Precision score, range [0,1]
- recall (`float`): Recall score, range [0,1]
- f1 (`float`): F1 score, range [0,1]
- true_positives (`int`): Number of correctly predicted cells
- false_positives (`int`): Number of incorrectly predicted cells
- false_negatives (`int`): Number of cells that were not predicted
Examples:
    >>> table_metric = evaluate.load("maqiuping59/table_markdown")
    >>> results = table_metric.compute(
    ...     predictions=["| | lobby | search | band | charge | chain ||--|--|--|--|--|--|| desire | 5 | 8 | 7 | 5 | 9 || wage | 1 | 5 | 3 | 8 | 5 |"],
    ...     references=["| | lobby | search | band | charge | chain ||--|--|--|--|--|--|| desire | 1 | 6 | 7 | 5 | 9 || wage | 1 | 5 | 2 | 8 | 5 |"]
    ... )
>>> print(results)
{'precision': 0.7, 'recall': 0.7, 'f1': 0.7, 'true_positives': 7, 'false_positives': 3, 'false_negatives': 3}
"""
_CITATION = """
"""
@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class Accuracy(evaluate.Metric):
def _info(self):
return evaluate.MetricInfo(
description=_DESCRIPTION,
citation=_CITATION,
inputs_description=_KWARGS_DESCRIPTION,
features=datasets.Features(
{
"predictions": datasets.Value("string"),
"references": datasets.Value("string"),
}
),
            reference_urls=[],
)
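    # Note: `evaluate` batches inputs, so `compute` expects `predictions` and
    # `references` as lists of Markdown table strings; `_compute` below joins each
    # list into a single string before parsing.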
    def _extract_markdown_table(self, text):
        """Collapse a Markdown table into a single-line string of `|`-delimited cells."""
        # Drop newlines and spaces so rows are separated only by "||".
        text = text.replace('\n', '')
        text = text.replace(" ", "")
        # A leading "||" means the top-left header cell is empty; give it a placeholder
        # name so row and column headers stay aligned during parsing.
        if text.startswith("||"):
            text = "|header" + text[1:]
        # Match runs of "|cell|cell|...|" and join them back together.
        pattern = r'\|(?:[^|]+\|)+[^|]+\|'
        matches = re.findall(pattern, text)
        if matches:
            return ''.join(matches)
        return None
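    # Illustrative example (assumed input, not from the original tests): the string
    # "|a|b|\n|--|--|\n|1|2|" collapses to "|a|b||--|--||1|2|", i.e. a single line
    # with rows separated by "||".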
    def _table_parse(self, table_str):
        """Parse a collapsed Markdown table into [set_of_headers, numeric_value] pairs."""
        # Rows are separated by "||"; split each row into its cells.
        rows = [row.strip("|").split("|") for row in table_str.split("||")]
        parse_result = []
        for row in rows:
            for j, cell in enumerate(row):
                index = []
                try:
                    value = float(cell)
                    # Row header: the first cell of this row, if it is not numeric.
                    try:
                        float(row[0])
                    except ValueError:
                        index.append(row[0])
                    # Column header: the first-row cell in the same column, if it is not numeric.
                    try:
                        float(rows[0][j])
                    except ValueError:
                        index.append(rows[0][j])
                    if index:
                        # Key every numeric cell by the set of its row/column headers.
                        parse_result.append([set(index), value])
                except (ValueError, IndexError):
                    # Skip non-numeric cells (headers, separators) and ragged rows.
                    continue
        return parse_result
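    # Illustrative example: the reference table used in main() below collapses to
    # "|check|lap||--|--||80|40|" and parses to [[{'check'}, 80.0], [{'lap'}, 40.0]],
    # since the "80|40" row has no non-numeric row header to add to the key.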
    def _markdown_to_table(self, markdown_str):
        """Extract and parse a Markdown table; returns None if no table is found."""
        table_str = self._extract_markdown_table(markdown_str)
        if table_str:
            return self._table_parse(table_str)
        return None
    def _calculate_table_metrics(self, pred_table, true_table):
        true_positives = 0
        false_positives = 0
        false_negatives = 0
        # Convert the [headers, value] pairs into dictionaries keyed by the sorted headers.
        pred_dict = {tuple(sorted(item[0])): item[1] for item in pred_table}
        true_dict = {tuple(sorted(item[0])): item[1] for item in true_table}
        # Compare predicted values with reference values, allowing a 5% tolerance.
        for key, pred_value in pred_dict.items():
            if key in true_dict:
                true_value = true_dict[key]
                if true_value == 0 and abs(pred_value) < 0.05:
                    true_positives += 1
                elif true_value != 0 and abs((pred_value - true_value) / true_value) < 0.05:
                    true_positives += 1
                else:
                    # The cell exists but its value is wrong: it is both a spurious
                    # prediction and a missed reference value.
                    false_positives += 1
                    false_negatives += 1
            else:
                false_positives += 1
        # Count false negatives: cells in the reference table that were never predicted.
        for key in true_dict:
            if key not in pred_dict:
                false_negatives += 1
precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0
recall = true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) > 0 else 0
f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
return {
'precision': precision,
'recall': recall,
'f1': f1,
'true_positives': true_positives,
'false_positives': false_positives,
'false_negatives': false_negatives
}
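    # Tolerance example for _calculate_table_metrics above (assumed values): a predicted
    # 103 against a reference 100 gives |103 - 100| / 100 = 0.03 < 0.05 and counts as a
    # true positive, while a predicted 106 gives 0.06 and counts as both a false
    # positive and a false negative.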
    def _compute(self, predictions, references):
        # `evaluate` passes lists of strings; concatenate them into a single table string.
        predictions = "".join(predictions)
        references = "".join(references)
        return self._calculate_table_metrics(self._markdown_to_table(predictions), self._markdown_to_table(references))
def main():
    accuracy_metric = Accuracy()
    # Compute the metrics.
    results = accuracy_metric.compute(
        predictions=["""
        ||view|denial|finger|check|
        |--|--|--|--|--|
        |tour|9|8|4|7|
        |wall|4|9|5|7|
        |sex|2|6|3|1|
        """],  # predicted table
        references=["""
        |check|lap|
        |--|--|
        |80|40|
        """],  # reference table
    )
    print(results)  # print the results
if __name__ == '__main__':
main()