maqiuping59 committed
Commit 12556ce · verified · 1 Parent(s): 3baca8d

Update accuracy.py

Files changed (1)
  1. accuracy.py +181 -182
accuracy.py CHANGED
@@ -1,182 +1,181 @@
- import re
- import json
- import evaluate
- import datasets
-
- _DESCRIPTION = """
- Table evaluation metrics for assessing the matching degree between predicted and reference tables. It calculates the following metrics:
-
- 1. Precision: The ratio of correctly predicted cells to the total number of cells in the predicted table
- 2. Recall: The ratio of correctly predicted cells to the total number of cells in the reference table
- 3. F1 Score: The harmonic mean of precision and recall
-
- These metrics help evaluate the accuracy of table data extraction or generation.
- """
-
-
- _KWARGS_DESCRIPTION = """
- Args:
-     predictions (`str`): Predicted table in Markdown format.
-     references (`str`): Reference table in Markdown format.
-
- Returns:
-     dict: A dictionary containing the following metrics:
-         - precision (`float`): Precision score, range [0,1]
-         - recall (`float`): Recall score, range [0,1]
-         - f1 (`float`): F1 score, range [0,1]
-         - true_positives (`int`): Number of correctly predicted cells
-         - false_positives (`int`): Number of incorrectly predicted cells
-         - false_negatives (`int`): Number of cells that were not predicted
-
- Examples:
-     >>> accuracy_metric = evaluate.load("accuracy")
-     >>> results = accuracy_metric.compute(
-     ...     predictions="| | lobby | search | band | charge | chain ||--|--|--|--|--|--|| desire | 5 | 8 | 7 | 5 | 9 || wage | 1 | 5 | 3 | 8 | 5 |",
-     ...     references="| | lobby | search | band | charge | chain ||--|--|--|--|--|--|| desire | 1 | 6 | 7 | 5 | 9 || wage | 1 | 5 | 2 | 8 | 5 |"
-     ... )
-     >>> print(results)
-     {'precision': 0.7, 'recall': 0.7, 'f1': 0.7, 'true_positives': 7, 'false_positives': 3, 'false_negatives': 3}
- """
-
-
- _CITATION = """
- @article{scikit-learn,
-   title={Scikit-learn: Machine Learning in {P}ython},
-   author={Pedregosa, F. and Varoquaux, G. and Gramfort, A. and Michel, V.
-           and Thirion, B. and Grisel, O. and Blondel, M. and Prettenhofer, P.
-           and Weiss, R. and Dubourg, V. and Vanderplas, J. and Passos, A. and
-           Cournapeau, D. and Brucher, M. and Perrot, M. and Duchesnay, E.},
-   journal={Journal of Machine Learning Research},
-   volume={12},
-   pages={2825--2830},
-   year={2011}
- }
- """
-
-
-
- @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
- class Accuracy(evaluate.Metric):
-     def _info(self):
-         return evaluate.MetricInfo(
-             description=_DESCRIPTION,
-             citation=_CITATION,
-             inputs_description=_KWARGS_DESCRIPTION,
-             features=datasets.Features(
-                 {
-                     "predictions": datasets.Value("string"),
-                     "references": datasets.Value("string"),
-                 }
-             ),
-             reference_urls=["https://scikit-learn.org/stable/modules/generated/sklearn.metrics.accuracy_score.html"],
-         )
-     def _extract_markdown_table(self, text):
-         text = text.replace('\n', '')
-         text = text.replace(" ", "")
-         pattern = r'\|(?:[^|]+\|)+[^|]+\|'
-         matches = re.findall(pattern, text)
-
-         if matches:
-             return ''.join(matches)
-
-         return None
-
-     def _table_to_dict(self, table_str):
-         result_dict = {}
-
-         table_str = table_str.lstrip("|").rstrip("|")
-         parts = table_str.split('||')
-         parts = [part for part in parts if "--" not in part]
-         legends = parts[0].split("|")
-
-         rows = len(parts)
-         if rows == 2:
-             nums = parts[1].split("|")
-             for i in range(len(nums)):
-                 result_dict[legends[i]] = float(nums[i])
-         elif rows >= 3:
-             for i in range(1, rows):
-                 pre_row = parts[i]
-                 pre_row = pre_row.split("|")
-                 label = pre_row[0]
-                 result_dict[label] = {}
-                 for j in range(1, len(pre_row)):
-                     result_dict[label][legends[j-1]] = float(pre_row[j])
-         else:
-             return None
-
-         return result_dict
-     def _markdown_to_dict(self, markdown_str):
-         table_str = self._extract_markdown_table(markdown_str)
-         if table_str:
-             return self._table_to_dict(table_str)
-         else:
-             return None
-
-     def _calculate_table_metrics(self, pred_table, true_table):
-         true_positives = 0
-         false_positives = 0
-         false_negatives = 0
-
-         # Iterate over every key-value pair in the predicted table
-         for key, pred_value in pred_table.items():
-             if key in true_table:
-                 true_value = true_table[key]
-                 if isinstance(pred_value, dict) and isinstance(true_value, dict):
-                     nested_metrics = self._calculate_table_metrics(pred_value, true_value)
-                     true_positives += nested_metrics['true_positives']
-                     false_positives += nested_metrics['false_positives']
-                     false_negatives += nested_metrics['false_negatives']
-                 # If the values are equal
-                 elif pred_value == true_value:
-                     true_positives += 1
-                 else:
-                     false_positives += 1
-                     false_negatives += 1
-             else:
-                 false_positives += 1
-
-         # Count reference cells that were never matched
-         for key in true_table:
-             if key not in pred_table:
-                 false_negatives += 1
-
-         # Compute the metrics
-         precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0
-         recall = true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) > 0 else 0
-         f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
-
-         return {
-             'precision': precision,
-             'recall': recall,
-             'f1': f1,
-             'true_positives': true_positives,
-             'false_positives': false_positives,
-             'false_negatives': false_negatives
-         }
-
-     def _compute(self, predictions, references):
-         predictions = "".join(predictions)
-         references = "".join(references)
-         return self._calculate_table_metrics(self._markdown_to_dict(predictions), self._markdown_to_dict(references))
-
-
- def main():
-     accuracy_metric = Accuracy()
-
-     # Compute the metrics
-     results = accuracy_metric.compute(
-         predictions=["""
- | | lobby | search | band | charge | chain ||--|--|--|--|--|--|| desire | 5 | 8 | 7 | 5 | 9 || wage | 1 | 5 | 3 | 8 | 5 |
- """],  # predicted table
-         references=["""
- | | lobby | search | band | charge | chain ||--|--|--|--|--|--|| desire | 1 | 6 | 7 | 5 | 9 || wage | 1 | 5 | 2 | 8 | 5 |
- """],  # reference table
-     )
-     print(results)  # print the results
-
- if __name__ == '__main__':
-     main()
-
-
-
 
+ import re
+ import json
+ import evaluate
+ import datasets
+
+ _DESCRIPTION = """
+ Table evaluation metrics for assessing the matching degree between predicted and reference tables. It calculates the following metrics:
+
+ 1. Precision: The ratio of correctly predicted cells to the total number of cells in the predicted table
+ 2. Recall: The ratio of correctly predicted cells to the total number of cells in the reference table
+ 3. F1 Score: The harmonic mean of precision and recall
+
+ These metrics help evaluate the accuracy of table data extraction or generation.
+ """
+
+
+ _KWARGS_DESCRIPTION = """
+ Args:
+     predictions (`str`): Predicted table in Markdown format.
+     references (`str`): Reference table in Markdown format.
+
+ Returns:
+     dict: A dictionary containing the following metrics:
+         - precision (`float`): Precision score, range [0,1]
+         - recall (`float`): Recall score, range [0,1]
+         - f1 (`float`): F1 score, range [0,1]
+         - true_positives (`int`): Number of correctly predicted cells
+         - false_positives (`int`): Number of incorrectly predicted cells
+         - false_negatives (`int`): Number of cells that were not predicted
+
+ Examples:
+     >>> accuracy_metric = evaluate.load("accuracy")
+     >>> results = accuracy_metric.compute(
+     ...     predictions="| | lobby | search | band | charge | chain ||--|--|--|--|--|--|| desire | 5 | 8 | 7 | 5 | 9 || wage | 1 | 5 | 3 | 8 | 5 |",
+     ...     references="| | lobby | search | band | charge | chain ||--|--|--|--|--|--|| desire | 1 | 6 | 7 | 5 | 9 || wage | 1 | 5 | 2 | 8 | 5 |"
+     ... )
+     >>> print(results)
+     {'precision': 0.7, 'recall': 0.7, 'f1': 0.7, 'true_positives': 7, 'false_positives': 3, 'false_negatives': 3}
+ """
+
+
+ _CITATION = """
+ @article{scikit-learn,
+   title={Scikit-learn: Machine Learning in {P}ython},
+   author={Pedregosa, F. and Varoquaux, G. and Gramfort, A. and Michel, V.
+           and Thirion, B. and Grisel, O. and Blondel, M. and Prettenhofer, P.
+           and Weiss, R. and Dubourg, V. and Vanderplas, J. and Passos, A. and
+           Cournapeau, D. and Brucher, M. and Perrot, M. and Duchesnay, E.},
+   journal={Journal of Machine Learning Research},
+   volume={12},
+   pages={2825--2830},
+   year={2011}
+ }
+ """
+
+
+
+ @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
+ class Accuracy(evaluate.Metric):
+     def _info(self):
+         return evaluate.MetricInfo(
+             description=_DESCRIPTION,
+             citation=_CITATION,
+             inputs_description=_KWARGS_DESCRIPTION,
+             features=datasets.Features(
+                 {
+                     "predictions": datasets.Value("string"),
+                     "references": datasets.Value("string"),
+                 }
+             ),
+             reference_urls=["https://scikit-learn.org/stable/modules/generated/sklearn.metrics.accuracy_score.html"],
+         )
+     def _extract_markdown_table(self, text):
+         text = text.replace('\n', '')
+         text = text.replace(" ", "")
+         pattern = r'\|(?:[^|]+\|)+[^|]+\|'
+         matches = re.findall(pattern, text)
+
+         if matches:
+             return ''.join(matches)
+
+         return None
+
+     def _table_to_dict(self, table_str):
+         result_dict = {}
+
+         table_str = table_str.lstrip("|").rstrip("|")
+         parts = table_str.split('||')
+         parts = [part for part in parts if "--" not in part]
+         legends = parts[0].split("|")
+
+         rows = len(parts)
+         if rows == 2:
+             nums = parts[1].split("|")
+             for i in range(len(nums)):
+                 result_dict[legends[i]] = float(nums[i])
+         elif rows >= 3:
+             for i in range(1, rows):
+                 pre_row = parts[i]
+                 pre_row = pre_row.split("|")
+                 label = pre_row[0]
+                 result_dict[label] = {}
+                 for j in range(1, len(pre_row)):
+                     result_dict[label][legends[j-1]] = float(pre_row[j])
+         else:
+             return None
+
+         return result_dict
+     def _markdown_to_dict(self, markdown_str):
+         table_str = self._extract_markdown_table(markdown_str)
+         if table_str:
+             return self._table_to_dict(table_str)
+         else:
+             return None
+
+     def _calculate_table_metrics(self, pred_table, true_table):
+         true_positives = 0
+         false_positives = 0
+         false_negatives = 0
+
+         for key, pred_value in pred_table.items():
+             if key in true_table:
+                 true_value = true_table[key]
+                 if isinstance(pred_value, dict) and isinstance(true_value, dict):
+                     nested_metrics = self._calculate_table_metrics(pred_value, true_value)
+                     true_positives += nested_metrics['true_positives']
+                     false_positives += nested_metrics['false_positives']
+                     false_negatives += nested_metrics['false_negatives']
+
+                 elif pred_value == true_value:
+                     true_positives += 1
+                 else:
+                     false_positives += 1
+                     false_negatives += 1
+             else:
+                 false_positives += 1
+
+
+         for key in true_table:
+             if key not in pred_table:
+                 false_negatives += 1
+
+
+         precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0
+         recall = true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) > 0 else 0
+         f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
+
+         return {
+             'precision': precision,
+             'recall': recall,
+             'f1': f1,
+             'true_positives': true_positives,
+             'false_positives': false_positives,
+             'false_negatives': false_negatives
+         }
+
+     def _compute(self, predictions, references):
+         predictions = "".join(predictions)
+         references = "".join(references)
+         return self._calculate_table_metrics(self._markdown_to_dict(predictions), self._markdown_to_dict(references))
+
+
+ def main():
+     accuracy_metric = Accuracy()
+
+
+     results = accuracy_metric.compute(
+         predictions=["""
+ | | lobby | search | band | charge | chain ||--|--|--|--|--|--|| desire | 5 | 8 | 7 | 5 | 9 || wage | 1 | 5 | 3 | 8 | 5 |
+ """],
+         references=["""
+ | | lobby | search | band | charge | chain ||--|--|--|--|--|--|| desire | 1 | 6 | 7 | 5 | 9 || wage | 1 | 5 | 2 | 8 | 5 |
+ """],
+     )
+     print(results)
+
+ if __name__ == '__main__':
+     main()
+
+
+
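
For context on the numbers in the docstring example: `_markdown_to_dict` parses each Markdown table into a dict mapping row label to `{column label: value}`, and `_calculate_table_metrics` counts cell-level agreements. A minimal sketch of that arithmetic on the two sample tables, assuming both share the same row and column labels as in the example (the variable names below are illustrative, not part of the file):

```python
# The two sample tables from the docstring, already parsed into nested dicts.
pred = {"desire": {"lobby": 5, "search": 8, "band": 7, "charge": 5, "chain": 9},
        "wage":   {"lobby": 1, "search": 5, "band": 3, "charge": 8, "chain": 5}}
ref  = {"desire": {"lobby": 1, "search": 6, "band": 7, "charge": 5, "chain": 9},
        "wage":   {"lobby": 1, "search": 5, "band": 2, "charge": 8, "chain": 5}}

cells = [(r, c) for r in pred for c in pred[r]]
tp = sum(pred[r][c] == ref[r][c] for r, c in cells)  # 7 matching cells
fp = sum(pred[r][c] != ref[r][c] for r, c in cells)  # 3 wrong predictions
fn = fp  # with identical labels, each mismatch also misses a reference cell

precision = tp / (tp + fp)                          # 7 / 10 = 0.7
recall = tp / (tp + fn)                             # 7 / 10 = 0.7
f1 = 2 * precision * recall / (precision + recall)  # 0.7
print(precision, recall, f1)                        # 0.7 0.7 0.7
```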
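In normal use the metric would be loaded through the evaluate library rather than run via `main()`. A sketch of that path, assuming a Hub repo id of the form `maqiuping59/accuracy` (hypothetical; the actual id is not shown on this page). Since `_compute` joins its inputs with `"".join(...)`, each table is passed as a single-element list:

```python
import evaluate

# Hypothetical repo id: substitute the real Hub path of this metric.
accuracy_metric = evaluate.load("maqiuping59/accuracy")

results = accuracy_metric.compute(
    predictions=["| | lobby | search | band | charge | chain ||--|--|--|--|--|--|| desire | 5 | 8 | 7 | 5 | 9 || wage | 1 | 5 | 3 | 8 | 5 |"],
    references=["| | lobby | search | band | charge | chain ||--|--|--|--|--|--|| desire | 1 | 6 | 7 | 5 | 9 || wage | 1 | 5 | 2 | 8 | 5 |"],
)
print(results)
# {'precision': 0.7, 'recall': 0.7, 'f1': 0.7, 'true_positives': 7, 'false_positives': 3, 'false_negatives': 3}
```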