import re
import json
import evaluate
import datasets
import pprint
# Long-form description rendered on the metric card and injected into the
# class docstring via add_start_docstrings below.
_DESCRIPTION = """
Table evaluation metrics for assessing the matching degree between predicted and reference tables. It calculates the following metrics:
1. Precision: The ratio of correctly predicted cells to the total number of cells in the predicted table
2. Recall: The ratio of correctly predicted cells to the total number of cells in the reference table
3. F1 Score: The harmonic mean of precision and recall
These metrics help evaluate the accuracy of table data extraction or generation.
"""
# Argument/return documentation consumed by evaluate.MetricInfo
# (inputs_description) and by the generated docstring.
_KWARGS_DESCRIPTION = """
Args:
predictions (`str`): Predicted table in Markdown format.
references (`str`): Reference table in Markdown format.
Returns:
dict: A dictionary containing the following metrics:
- precision (`float`): Precision score, range [0,1]
- recall (`float`): Recall score, range [0,1]
- f1 (`float`): F1 score, range [0,1]
- true_positives (`int`): Number of correctly predicted cells
- false_positives (`int`): Number of incorrectly predicted cells
- false_negatives (`int`): Number of cells that were not predicted
Examples:
>>> accuracy_metric = evaluate.load("accuracy")
>>> results = accuracy_metric.compute(
... predictions="| | lobby | search | band | charge | chain ||--|--|--|--|--|--|| desire | 5 | 8 | 7 | 5 | 9 || wage | 1 | 5 | 3 | 8 | 5 |",
... references="| | lobby | search | band | charge | chain ||--|--|--|--|--|--|| desire | 1 | 6 | 7 | 5 | 9 || wage | 1 | 5 | 2 | 8 | 5 |"
... )
>>> print(results)
{'precision': 0.7, 'recall': 0.7, 'f1': 0.7, 'true_positives': 7, 'false_positives': 3, 'false_negatives': 3}
"""
# No published citation for this metric yet.
_CITATION = """
"""
@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class Accuracy(evaluate.Metric):
    """Cell-level precision/recall/F1 between two numeric Markdown tables.

    Each numeric cell is addressed by the set of its non-numeric row/column
    headers; a predicted cell matches a reference cell when the key is equal
    and the value is within 5% relative tolerance (absolute tolerance 0.05
    when the reference value is 0).
    """

    def _info(self):
        """Return the metric metadata consumed by the `evaluate` library."""
        return evaluate.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(
                {
                    "predictions": datasets.Value("string"),
                    "references": datasets.Value("string"),
                }
            ),
            reference_urls=["https://scikit-learn.org/stable/modules/generated/sklearn.metrics.accuracy_score.html"],
        )

    def _extract_markdown_table(self, text):
        """Return the Markdown table rows found in *text* joined together, or None.

        Newlines and spaces are removed first so a multi-line table collapses
        into one '|a|b||c|d|' stream. A leading empty header cell ('||') is
        replaced by a synthetic 'header' label so that cell is not dropped.
        """
        text = text.replace('\n', '')
        text = text.replace(" ", "")
        if text.startswith("||"):
            text = "|header" + text[1:]
        # One or more '|'-delimited cells terminated by a final '|'.
        pattern = r'\|(?:[^|]+\|)+[^|]+\|'
        matches = re.findall(pattern, text)
        if matches:
            return ''.join(matches)
        return None

    def _table_parse(self, table_str):
        """Parse a collapsed table string into a list of [header-set, value] pairs.

        Every cell that converts to float is recorded together with the set of
        non-numeric headers that address it: the first cell of its row and the
        cell of the first row in the same column. Cells whose row AND column
        headers are both numeric (or absent) are skipped, as are non-numeric
        cells and ragged rows.
        """
        rows = [item.strip("|").split("|") for item in table_str.split("||")]
        parse_result = []
        for row in rows:
            for j, cell in enumerate(row):
                index = []
                # ValueError -> the cell itself is not numeric;
                # IndexError -> ragged row with no matching column header.
                try:
                    value = float(cell)
                    try:
                        float(row[0])
                    except ValueError:
                        # Row label is non-numeric: part of the cell's key.
                        index.append(row[0])
                    try:
                        float(rows[0][j])
                    except ValueError:
                        # Column label is non-numeric: part of the cell's key.
                        index.append(rows[0][j])
                    if len(index) > 0:
                        parse_result.append([set(index), value])
                except (ValueError, IndexError):
                    continue
        return parse_result

    # Backward-compatible alias for the original (misspelled) method name.
    _table_prase = _table_parse

    def _markdown_to_table(self, markdown_str):
        """Extract and parse the Markdown table in *markdown_str*.

        Returns a list of [header-set, value] pairs, or None when no table
        could be found.
        """
        table_str = self._extract_markdown_table(markdown_str)
        if table_str:
            return self._table_parse(table_str)
        return None

    def _calculate_table_metrics(self, pred_table, true_table):
        """Compute precision/recall/F1 between two parsed tables.

        Either argument may be None (no table was extracted); it is then
        treated as an empty table instead of raising TypeError.
        """
        pred_table = pred_table or []
        true_table = true_table or []
        true_positives = 0
        false_positives = 0
        false_negatives = 0
        # Key each cell by its sorted header tuple for O(1) lookups.
        pred_dict = {tuple(sorted(item[0])): item[1] for item in pred_table}
        true_dict = {tuple(sorted(item[0])): item[1] for item in true_table}
        for key, pred_value in pred_dict.items():
            if key in true_dict:
                true_value = true_dict[key]
                if true_value == 0 and abs(pred_value) < 0.05:
                    true_positives += 1
                elif true_value != 0 and abs((pred_value - true_value) / true_value) < 0.05:
                    true_positives += 1
                else:
                    # Existing cell with a wrong value counts against both sides.
                    false_positives += 1
                    false_negatives += 1
            else:
                false_positives += 1
        # Reference cells that were never predicted at all.
        for key in true_dict:
            if key not in pred_dict:
                false_negatives += 1
        precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0
        recall = true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) > 0 else 0
        f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
        return {
            'precision': precision,
            'recall': recall,
            'f1': f1,
            'true_positives': true_positives,
            'false_positives': false_positives,
            'false_negatives': false_negatives
        }

    def _compute(self, predictions, references):
        """Join the input sequences into single strings and score them."""
        predictions = "".join(predictions)
        references = "".join(references)
        return self._calculate_table_metrics(
            self._markdown_to_table(predictions),
            self._markdown_to_table(references),
        )
def main():
    """Demo: score a prediction table against a reference table and print the result.

    The two tables share no (row, column) header pair, so every parsed cell
    is either a false positive or a false negative.
    """
    accuracy_metric = Accuracy()
    # Compute the metrics on a sample prediction/reference pair.
    results = accuracy_metric.compute(
        predictions=["""
    ||view|denial|finger|check|
    |--|--|--|--|--|
    |tour|9|8|4|7|
    |wall|4|9|5|7|
    |sex|2|6|3|1|
    """],  # predicted table
        references=["""
    |check|lap|
    |--|--|
    |80|40|
    """],  # reference table
    )
    print(results)  # print the resulting metrics dict
# Script entry point: run the demo when executed directly.
if __name__ == '__main__':
    main()