maqiuping59 committed
Commit 6ee5ac2 · verified · 1 Parent(s): 980c078

Upload 3 files

Files changed (3):
  1. accuracy.py +182 -0
  2. app.py +6 -0
  3. requirements.txt +0 -0
accuracy.py ADDED
@@ -0,0 +1,182 @@
import re
import json
import evaluate
import datasets

_DESCRIPTION = """
Table evaluation metrics for assessing how well a predicted table matches a reference table. The module calculates the following metrics:

1. Precision: the ratio of correctly predicted cells to the total number of cells in the predicted table
2. Recall: the ratio of correctly predicted cells to the total number of cells in the reference table
3. F1 score: the harmonic mean of precision and recall

These metrics help evaluate the accuracy of table data extraction or generation.
"""


_KWARGS_DESCRIPTION = """
Args:
    predictions (`str`): Predicted table in Markdown format.
    references (`str`): Reference table in Markdown format.

Returns:
    dict: A dictionary containing the following metrics:
        - precision (`float`): Precision score, range [0, 1]
        - recall (`float`): Recall score, range [0, 1]
        - f1 (`float`): F1 score, range [0, 1]
        - true_positives (`int`): Number of correctly predicted cells
        - false_positives (`int`): Number of incorrectly predicted cells
        - false_negatives (`int`): Number of reference cells that were not predicted

Examples:
    >>> accuracy_metric = evaluate.load("accuracy")
    >>> results = accuracy_metric.compute(
    ...     predictions="| | lobby | search | band | charge | chain ||--|--|--|--|--|--|| desire | 5 | 8 | 7 | 5 | 9 || wage | 1 | 5 | 3 | 8 | 5 |",
    ...     references="| | lobby | search | band | charge | chain ||--|--|--|--|--|--|| desire | 1 | 6 | 7 | 5 | 9 || wage | 1 | 5 | 2 | 8 | 5 |"
    ... )
    >>> print(results)
    {'precision': 0.7, 'recall': 0.7, 'f1': 0.7, 'true_positives': 7, 'false_positives': 3, 'false_negatives': 3}
"""


_CITATION = """
@article{scikit-learn,
  title={Scikit-learn: Machine Learning in {P}ython},
  author={Pedregosa, F. and Varoquaux, G. and Gramfort, A. and Michel, V.
          and Thirion, B. and Grisel, O. and Blondel, M. and Prettenhofer, P.
          and Weiss, R. and Dubourg, V. and Vanderplas, J. and Passos, A. and
          Cournapeau, D. and Brucher, M. and Perrot, M. and Duchesnay, E.},
  journal={Journal of Machine Learning Research},
  volume={12},
  pages={2825--2830},
  year={2011}
}
"""
@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class Accuracy(evaluate.Metric):
    def _info(self):
        return evaluate.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(
                {
                    "predictions": datasets.Value("string"),
                    "references": datasets.Value("string"),
                }
            ),
            reference_urls=["https://scikit-learn.org/stable/modules/generated/sklearn.metrics.accuracy_score.html"],
        )

    def _extract_markdown_table(self, text):
        # Strip newlines and spaces, then pull out every pipe-delimited row.
        text = text.replace('\n', '')
        text = text.replace(" ", "")
        pattern = r'\|(?:[^|]+\|)+[^|]+\|'
        matches = re.findall(pattern, text)

        if matches:
            return ''.join(matches)

        return None

    def _table_to_dict(self, table_str):
        # Convert the flattened Markdown table into a dict keyed by column
        # headers (single data row) or into nested {row label: {header: value}}
        # dicts (two or more data rows).
        result_dict = {}

        table_str = table_str.lstrip("|").rstrip("|")
        parts = table_str.split('||')
        parts = [part for part in parts if "--" not in part]  # drop the separator row
        legends = parts[0].split("|")  # column headers

        rows = len(parts)
        if rows == 2:
            nums = parts[1].split("|")
            for i in range(len(nums)):
                result_dict[legends[i]] = float(nums[i])
        elif rows >= 3:
            for i in range(1, rows):
                pre_row = parts[i].split("|")
                label = pre_row[0]
                result_dict[label] = {}
                for j in range(1, len(pre_row)):
                    result_dict[label][legends[j - 1]] = float(pre_row[j])
        else:
            return None

        return result_dict

    def _markdown_to_dict(self, markdown_str):
        table_str = self._extract_markdown_table(markdown_str)
        if table_str:
            return self._table_to_dict(table_str)
        else:
            return None

    def _calculate_table_metrics(self, pred_table, true_table):
        true_positives = 0
        false_positives = 0
        false_negatives = 0

        # Iterate over every key/value pair of the predicted table
        for key, pred_value in pred_table.items():
            if key in true_table:
                true_value = true_table[key]
                if isinstance(pred_value, dict) and isinstance(true_value, dict):
                    nested_metrics = self._calculate_table_metrics(pred_value, true_value)
                    true_positives += nested_metrics['true_positives']
                    false_positives += nested_metrics['false_positives']
                    false_negatives += nested_metrics['false_negatives']
                # If the values are equal
                elif pred_value == true_value:
                    true_positives += 1
                else:
                    false_positives += 1
                    false_negatives += 1
            else:
                false_positives += 1

        # Count reference cells that were never predicted
        for key in true_table:
            if key not in pred_table:
                false_negatives += 1

        # Compute the metrics
        precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0
        recall = true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) > 0 else 0
        f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

        return {
            'precision': precision,
            'recall': recall,
            'f1': f1,
            'true_positives': true_positives,
            'false_positives': false_positives,
            'false_negatives': false_negatives
        }

    def _compute(self, predictions, references):
        predictions = "".join(predictions)
        references = "".join(references)
        return self._calculate_table_metrics(self._markdown_to_dict(predictions), self._markdown_to_dict(references))


def main():
    accuracy_metric = Accuracy()

    # Compute the metrics
    results = accuracy_metric.compute(
        predictions=["""
        | | lobby | search | band | charge | chain ||--|--|--|--|--|--|| desire | 5 | 8 | 7 | 5 | 9 || wage | 1 | 5 | 3 | 8 | 5 |
        """],  # predicted table
        references=["""
        | | lobby | search | band | charge | chain ||--|--|--|--|--|--|| desire | 1 | 6 | 7 | 5 | 9 || wage | 1 | 5 | 2 | 8 | 5 |
        """],  # reference table
    )
    print(results)  # print the results


if __name__ == '__main__':
    main()
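For a quick sanity check of the docstring example: the two tables share 10 cells and 7 of them match, so precision and recall are both 7/10, and F1, their harmonic mean, is also 0.7. The standalone sketch below is separate from the uploaded files (the helper name prf_from_counts is illustrative only) and simply redoes that arithmetic from the raw counts.

# Minimal sketch, not part of the commit: rebuild precision/recall/F1 from
# the cell-level counts that Accuracy._calculate_table_metrics reports.
def prf_from_counts(tp: int, fp: int, fn: int) -> dict:
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
    return {"precision": precision, "recall": recall, "f1": f1}

# Docstring example: 10 cells, 3 mismatches -> 7 TP, 3 FP, 3 FN.
print(prf_from_counts(7, 3, 3))  # precision = recall = 7/10; F1 is likewise 0.7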
app.py ADDED
@@ -0,0 +1,6 @@
import evaluate
from evaluate.utils import launch_gradio_widget


# Load the metric module from the Hub and expose it as a Gradio demo.
module = evaluate.load("maqiuping59/table_markdown")
launch_gradio_widget(module)
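Once the Space builds, the metric can also be loaded by its Hub id, exactly as app.py does, and called on Markdown table strings. A minimal usage sketch, assuming the id maqiuping59/table_markdown resolves on the Hub and the tables use the header-plus-data-rows layout the parser expects:

# Minimal usage sketch, not part of the commit; assumes the Space id from
# app.py is reachable on the Hub.
import evaluate

table_metric = evaluate.load("maqiuping59/table_markdown")
results = table_metric.compute(
    predictions=["| | a | b ||--|--|--|| x | 1 | 2 || y | 3 | 4 |"],
    references=["| | a | b ||--|--|--|| x | 1 | 2 || y | 3 | 5 |"],
)
print(results)  # 3 of 4 cells match -> precision = recall = 0.75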
requirements.txt ADDED
File without changes