maqiuping59 committed (verified)
Commit 41b45e8 · 1 Parent(s): 980aa29

Update metric.py

Files changed (1):
  1. metric.py +268 -86
metric.py CHANGED
@@ -1,9 +1,205 @@
1
  import re
2
  import json
3
  import evaluate
4
  import datasets
5
- from typing import Set, Tuple, List, Dict, Any
6
- from dataclasses import dataclass
7
 
8
  _DESCRIPTION = """
9
  Table evaluation metrics for assessing the matching degree between predicted and reference tables. It calculates the following metrics:
@@ -42,23 +238,21 @@ Examples:
42
 
43
 
44
  _CITATION = """
45
-
46
  """
47
 
48
 
49
- @dataclass(frozen=True)
50
- class TableCell:
51
- labels: frozenset[str] # Using frozenset for hashable unordered pair
52
- value: float
53
-
54
- def __eq__(self, other):
55
- if not isinstance(other, TableCell):
56
- return False
57
- return self.labels == other.labels and abs(self.value - other.value) < 0.05
58
-
59
- def __hash__(self):
60
- return hash((self.labels, round(self.value, 3))) # Round to handle float comparison
61
-
62
  @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
63
  class Accuracy(evaluate.Metric):
64
  def _info(self):
@@ -85,66 +279,67 @@ class Accuracy(evaluate.Metric):
85
 
86
  return None
87
 
88
- def _table_to_cell_set(self, table_str: str) -> Set[TableCell]:
89
- """Convert markdown table string to a set of TableCell objects."""
90
- result_set = set()
91
-
92
  table_str = table_str.lstrip("|").rstrip("|")
93
  parts = table_str.split('||')
94
  parts = [part for part in parts if "--" not in part]
95
-
96
- if not parts:
97
- return result_set
98
-
99
  legends = parts[0].split("|")
100
- legends = [l.strip() for l in legends if l.strip()]
101
 
102
  rows = len(parts)
103
- if rows == 2: # Single row table - use single label
104
  nums = parts[1].split("|")
105
- nums = [n.strip() for n in nums if n.strip()]
106
- for i, num in enumerate(nums):
107
- try:
108
- value = float(num)
109
- # For single row tables, use a single label
110
- cell = TableCell(frozenset([legends[i]]), value)
111
- result_set.add(cell)
112
- except ValueError:
113
- continue
114
- elif rows >= 3: # Multi-row table - use label pairs
115
- for i in range(1, rows):
116
- row = parts[i].split("|")
117
- row = [r.strip() for r in row if r.strip()]
118
- if not row:
119
- continue
120
-
121
- row_label = row[0]
122
- for j, num in enumerate(row[1:], 1):
123
- if j >= len(legends):
124
- continue
125
- try:
126
- value = float(num)
127
- # For multi-row tables, use label pairs
128
- cell = TableCell(frozenset([row_label, legends[j-1]]), value)
129
- result_set.add(cell)
130
- except ValueError:
131
- continue
132
 
133
- return result_set
134
-
135
- def _markdown_to_cell_set(self, markdown_str: str) -> Set[TableCell]:
136
- """Convert markdown string to a set of TableCell objects."""
137
  table_str = self._extract_markdown_table(markdown_str)
138
  if table_str:
139
- return self._table_to_cell_set(table_str)
140
- return set()
141
 
142
- def _calculate_table_metrics(self, pred_cells: Set[TableCell], true_cells: Set[TableCell]) -> Dict[str, Any]:
143
- """Calculate metrics using cell set comparison."""
144
- true_positives = len(pred_cells.intersection(true_cells))
145
- false_positives = len(pred_cells - true_cells)
146
- false_negatives = len(true_cells - pred_cells)
147
 
148
  precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0
149
  recall = true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) > 0 else 0
150
  f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
@@ -161,39 +356,26 @@ class Accuracy(evaluate.Metric):
161
  def _compute(self, predictions, references):
162
  predictions = "".join(predictions)
163
  references = "".join(references)
164
- pred_cells = self._markdown_to_cell_set(predictions)
165
- true_cells = self._markdown_to_cell_set(references)
166
- return self._calculate_table_metrics(pred_cells, true_cells)
167
 
168
 
169
  def main():
170
  accuracy_metric = Accuracy()
171
 
172
- # Test with different table formats
173
- # Test 1: Single row table
174
- results1 = accuracy_metric.compute(
175
- predictions=["""
176
- | | value1 | value2 | value3 ||--|--|--|--|| data | 1.01 | 2 | 3 |
177
- """],
178
- references=["""
179
- | | value1 | value2 | value3 ||--|--|--|--|| data | 1 | 2 | 3 |
180
- """],
181
- )
182
- print("Single row table test:", results1)
183
-
184
- # Test 2: Multi-row table (transposed)
185
- results2 = accuracy_metric.compute(
186
  predictions=["""
187
- | | desire | wage ||--|--|--|| lobby | 5.01 | 1 || search | 8 | 5 || band | 7 | 3 || charge | 5 | 8 || chain | 9 | 5 |
188
- """],
189
  references=["""
190
- | | lobby | search | band | charge | chain ||--|--|--|--|--|--|| desire | 5.01 | 8 | 7 | 5 | 9 || wage | 1 | 5 | 3 | 8 | 5 |
191
- """],
192
  )
193
- print("Multi-row table test:", results2)
194
 
195
  if __name__ == '__main__':
196
  main()
197
 
198
 
199
 
 
 
1
+ # import re
2
+ # import json
3
+ # import evaluate
4
+ # import datasets
5
+ # from typing import Set, Tuple, List, Dict, Any
6
+ # from dataclasses import dataclass
7
+
8
+ # _DESCRIPTION = """
9
+ # Table evaluation metrics for assessing the matching degree between predicted and reference tables. It calculates the following metrics:
10
+
11
+ # 1. Precision: The ratio of correctly predicted cells to the total number of cells in the predicted table
12
+ # 2. Recall: The ratio of correctly predicted cells to the total number of cells in the reference table
13
+ # 3. F1 Score: The harmonic mean of precision and recall
14
+
15
+ # These metrics help evaluate the accuracy of table data extraction or generation.
16
+ # """
17
+
18
+
19
+ # _KWARGS_DESCRIPTION = """
20
+ # Args:
21
+ # predictions (`str`): Predicted table in Markdown format.
22
+ # references (`str`): Reference table in Markdown format.
23
+
24
+ # Returns:
25
+ # dict: A dictionary containing the following metrics:
26
+ # - precision (`float`): Precision score, range [0,1]
27
+ # - recall (`float`): Recall score, range [0,1]
28
+ # - f1 (`float`): F1 score, range [0,1]
29
+ # - true_positives (`int`): Number of correctly predicted cells
30
+ # - false_positives (`int`): Number of incorrectly predicted cells
31
+ # - false_negatives (`int`): Number of cells that were not predicted
32
+
33
+ # Examples:
34
+ # >>> accuracy_metric = evaluate.load("accuracy")
35
+ # >>> results = accuracy_metric.compute(
36
+ # ... predictions="| | lobby | search | band | charge | chain ||--|--|--|--|--|--|| desire | 5 | 8 | 7 | 5 | 9 || wage | 1 | 5 | 3 | 8 | 5 |",
37
+ # ... references="| | lobby | search | band | charge | chain ||--|--|--|--|--|--|| desire | 1 | 6 | 7 | 5 | 9 || wage | 1 | 5 | 2 | 8 | 5 |"
38
+ # ... )
39
+ # >>> print(results)
40
+ # {'precision': 0.7, 'recall': 0.7, 'f1': 0.7, 'true_positives': 7, 'false_positives': 3, 'false_negatives': 3}
41
+ # """
42
+
43
+
44
+ # _CITATION = """
45
+
46
+ # """
47
+
48
+
49
+ # @dataclass(frozen=True)
50
+ # class TableCell:
51
+ # labels: frozenset[str] # Using frozenset for hashable unordered pair
52
+ # value: float
53
+
54
+ # def __eq__(self, other):
55
+ # if not isinstance(other, TableCell):
56
+ # return False
57
+ # return self.labels == other.labels and abs(self.value - other.value) < 0.05
58
+
59
+ # def __hash__(self):
60
+ # return hash((self.labels, round(self.value, 3))) # Round to handle float comparison
61
+
62
+ # @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
63
+ # class Accuracy(evaluate.Metric):
64
+ # def _info(self):
65
+ # return evaluate.MetricInfo(
66
+ # description=_DESCRIPTION,
67
+ # citation=_CITATION,
68
+ # inputs_description=_KWARGS_DESCRIPTION,
69
+ # features=datasets.Features(
70
+ # {
71
+ # "predictions": datasets.Value("string"),
72
+ # "references": datasets.Value("string"),
73
+ # }
74
+ # ),
75
+ # reference_urls=["https://scikit-learn.org/stable/modules/generated/sklearn.metrics.accuracy_score.html"],
76
+ # )
77
+ # def _extract_markdown_table(self,text):
78
+ # text = text.replace('\n', '')
79
+ # text = text.replace(" ","")
80
+ # pattern = r'\|(?:[^|]+\|)+[^|]+\|'
81
+ # matches = re.findall(pattern, text)
82
+
83
+ # if matches:
84
+ # return ''.join(matches)
85
+
86
+ # return None
87
+
88
+ # def _table_to_cell_set(self, table_str: str) -> Set[TableCell]:
89
+ # """Convert markdown table string to a set of TableCell objects."""
90
+ # result_set = set()
91
+
92
+ # table_str = table_str.lstrip("|").rstrip("|")
93
+ # parts = table_str.split('||')
94
+ # parts = [part for part in parts if "--" not in part]
95
+
96
+ # if not parts:
97
+ # return result_set
98
+
99
+ # legends = parts[0].split("|")
100
+ # legends = [l.strip() for l in legends if l.strip()]
101
+
102
+ # rows = len(parts)
103
+ # if rows == 2: # Single row table - use single label
104
+ # nums = parts[1].split("|")
105
+ # nums = [n.strip() for n in nums if n.strip()]
106
+ # for i, num in enumerate(nums):
107
+ # try:
108
+ # value = float(num)
109
+ # # For single row tables, use a single label
110
+ # cell = TableCell(frozenset([legends[i]]), value)
111
+ # result_set.add(cell)
112
+ # except ValueError:
113
+ # continue
114
+ # elif rows >= 3: # Multi-row table - use label pairs
115
+ # for i in range(1, rows):
116
+ # row = parts[i].split("|")
117
+ # row = [r.strip() for r in row if r.strip()]
118
+ # if not row:
119
+ # continue
120
+
121
+ # row_label = row[0]
122
+ # for j, num in enumerate(row[1:], 1):
123
+ # if j >= len(legends):
124
+ # continue
125
+ # try:
126
+ # value = float(num)
127
+ # # For multi-row tables, use label pairs
128
+ # cell = TableCell(frozenset([row_label, legends[j-1]]), value)
129
+ # result_set.add(cell)
130
+ # except ValueError:
131
+ # continue
132
+
133
+ # return result_set
134
+
135
+ # def _markdown_to_cell_set(self, markdown_str: str) -> Set[TableCell]:
136
+ # """Convert markdown string to a set of TableCell objects."""
137
+ # table_str = self._extract_markdown_table(markdown_str)
138
+ # if table_str:
139
+ # return self._table_to_cell_set(table_str)
140
+ # return set()
141
+
142
+ # def _calculate_table_metrics(self, pred_cells: Set[TableCell], true_cells: Set[TableCell]) -> Dict[str, Any]:
143
+ # """Calculate metrics using cell set comparison."""
144
+ # true_positives = len(pred_cells.intersection(true_cells))
145
+ # false_positives = len(pred_cells - true_cells)
146
+ # false_negatives = len(true_cells - pred_cells)
147
+
148
+ # precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0
149
+ # recall = true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) > 0 else 0
150
+ # f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
151
+
152
+ # return {
153
+ # 'precision': precision,
154
+ # 'recall': recall,
155
+ # 'f1': f1,
156
+ # 'true_positives': true_positives,
157
+ # 'false_positives': false_positives,
158
+ # 'false_negatives': false_negatives
159
+ # }
160
+
161
+ # def _compute(self, predictions, references):
162
+ # predictions = "".join(predictions)
163
+ # references = "".join(references)
164
+ # pred_cells = self._markdown_to_cell_set(predictions)
165
+ # true_cells = self._markdown_to_cell_set(references)
166
+ # return self._calculate_table_metrics(pred_cells, true_cells)
167
+
168
+
169
+ # def main():
170
+ # accuracy_metric = Accuracy()
171
+
172
+ # # Test with different table formats
173
+ # # Test 1: Single row table
174
+ # results1 = accuracy_metric.compute(
175
+ # predictions=["""
176
+ # | | value1 | value2 | value3 ||--|--|--|--|| data | 1.01 | 2 | 3 |
177
+ # """],
178
+ # references=["""
179
+ # | | value1 | value2 | value3 ||--|--|--|--|| data | 1 | 2 | 3 |
180
+ # """],
181
+ # )
182
+ # print("Single row table test:", results1)
183
+
184
+ # # Test 2: Multi-row table (transposed)
185
+ # results2 = accuracy_metric.compute(
186
+ # predictions=["""
187
+ # | | desire | wage ||--|--|--|| lobby | 5.01 | 1 || search | 8 | 5 || band | 7 | 3 || charge | 5 | 8 || chain | 9 | 5 |
188
+ # """],
189
+ # references=["""
190
+ # | | lobby | search | band | charge | chain ||--|--|--|--|--|--|| desire | 5.01 | 8 | 7 | 5 | 9 || wage | 1 | 5 | 3 | 8 | 5 |
191
+ # """],
192
+ # )
193
+ # print("Multi-row table test:", results2)
194
+
195
+ # if __name__ == '__main__':
196
+ # main()
197
+
198
+
199
  import re
200
  import json
201
  import evaluate
202
  import datasets
203
 
204
  _DESCRIPTION = """
205
  Table evaluation metrics for assessing the matching degree between predicted and reference tables. It calculates the following metrics:
 
238
 
239
 
240
  _CITATION = """
241
+ @article{scikit-learn,
242
+ title={Scikit-learn: Machine Learning in {P}ython},
243
+ author={Pedregosa, F. and Varoquaux, G. and Gramfort, A. and Michel, V.
244
+ and Thirion, B. and Grisel, O. and Blondel, M. and Prettenhofer, P.
245
+ and Weiss, R. and Dubourg, V. and Vanderplas, J. and Passos, A. and
246
+ Cournapeau, D. and Brucher, M. and Perrot, M. and Duchesnay, E.},
247
+ journal={Journal of Machine Learning Research},
248
+ volume={12},
249
+ pages={2825--2830},
250
+ year={2011}
251
+ }
252
  """
253
 
254
 
255
+
256
  @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
257
  class Accuracy(evaluate.Metric):
258
  def _info(self):
 
279
 
280
  return None
281
 
282
+ def _table_to_dict(self,table_str):
283
+ result_dict = {}
284
+
 
285
  table_str = table_str.lstrip("|").rstrip("|")
286
  parts = table_str.split('||')
287
  parts = [part for part in parts if "--" not in part]
288
  legends = parts[0].split("|")
 
289
 
290
  rows = len(parts)
291
+ if rows == 2:
292
  nums = parts[1].split("|")
293
+ for i in range(len(nums)):
294
+ result_dict[legends[i]]=float(nums[i])
295
+ elif rows >=3:
296
+ for i in range(1,rows):
297
+ pre_row = parts[i]
298
+ pre_row = pre_row.split("|")
299
+ label = pre_row[0]
300
+ result_dict[label] = {}
301
+ for j in range(1,len(pre_row)):
302
+ result_dict[label][legends[j-1]] = float(pre_row[j])
303
+ else:
304
+ return None
305
 
306
+ return result_dict
307
+ def _markdown_to_dict(self,markdown_str):
 
 
308
  table_str = self._extract_markdown_table(markdown_str)
309
  if table_str:
310
+ return self._table_to_dict(table_str)
311
+ else:
312
+ return None
313
+
314
+ def _calculate_table_metrics(self,pred_table, true_table):
315
+ true_positives = 0
316
+ false_positives = 0
317
+ false_negatives = 0
318
 
319
+ # Iterate over all key-value pairs in the predicted table
320
+ for key, pred_value in pred_table.items():
321
+ if key in true_table:
322
+ true_value = true_table[key]
323
+ if isinstance(pred_value, dict) and isinstance(true_value, dict):
324
+ nested_metrics = self._calculate_table_metrics(pred_value, true_value)
325
+ true_positives += nested_metrics['true_positives']
326
+ false_positives += nested_metrics['false_positives']
327
+ false_negatives += nested_metrics['false_negatives']
328
+ # If the values are equal
329
+ elif pred_value == true_value:
330
+ true_positives += 1
331
+ else:
332
+ false_positives += 1
333
+ false_negatives += 1
334
+ else:
335
+ false_positives += 1
336
 
337
+ # Count reference values that were never predicted
338
+ for key in true_table:
339
+ if key not in pred_table:
340
+ false_negatives += 1
341
+
342
+ # Compute the metrics
343
  precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0
344
  recall = true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) > 0 else 0
345
  f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
 
356
  def _compute(self, predictions, references):
357
  predictions = "".join(predictions)
358
  references = "".join(references)
359
+ return self._calculate_table_metrics(self._markdown_to_dict(predictions), self._markdown_to_dict(references))
360
 
361
 
362
  def main():
363
  accuracy_metric = Accuracy()
364
 
365
+ # Compute the metrics
366
+ results = accuracy_metric.compute(
367
  predictions=["""
368
+ | | lobby | search | band | charge | chain ||--|--|--|--|--|--|| desire | 5 | 8 | 7 | 5 | 9 || wage | 1 | 5 | 3 | 8 | 5 |
369
+ """], # 预测的表格
370
  references=["""
371
+ | | lobby | search | band | charge | chain ||--|--|--|--|--|--|| desire | 1 | 6 | 7 | 5 | 9 || wage | 1 | 5 | 2 | 8 | 5 |
372
+ """], # 参考的表格
373
  )
374
+ print(results) # print the results
375
 
376
  if __name__ == '__main__':
377
  main()
378
 
379
 
380
 
381
+
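
For quick reference, here is a minimal usage sketch of the dict-based metric added in this commit, mirroring the main() test above. It assumes the file is saved locally and importable as a module named metric; the module name is an assumption for illustration, and the example tables are the ones from the docstring, where 7 of the 10 numeric cells agree.

# Minimal usage sketch (module name `metric` is assumed, not part of the commit).
from metric import Accuracy

accuracy_metric = Accuracy()

# Tables from the docstring example: 7 of 10 cells match, so precision and recall are 0.7.
prediction = "| | lobby | search | band | charge | chain ||--|--|--|--|--|--|| desire | 5 | 8 | 7 | 5 | 9 || wage | 1 | 5 | 3 | 8 | 5 |"
reference = "| | lobby | search | band | charge | chain ||--|--|--|--|--|--|| desire | 1 | 6 | 7 | 5 | 9 || wage | 1 | 5 | 2 | 8 | 5 |"

results = accuracy_metric.compute(predictions=[prediction], references=[reference])
print(results)  # e.g. {'precision': 0.7, 'recall': 0.7, 'f1': 0.7, 'true_positives': 7, 'false_positives': 3, 'false_negatives': 3}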