maqiuping59 committed
Commit 980aa29 · verified · 1 Parent(s): 56dcd48

Update metric.py

Files changed (1):
  1. metric.py +86 -57
metric.py CHANGED
@@ -2,6 +2,8 @@ import re
 import json
 import evaluate
 import datasets
+from typing import Set, Tuple, List, Dict, Any
+from dataclasses import dataclass

 _DESCRIPTION = """
 Table evaluation metrics for assessing the matching degree between predicted and reference tables. It calculates the following metrics:
@@ -44,7 +46,19 @@ _CITATION = """
 """


-
+@dataclass(frozen=True)
+class TableCell:
+    labels: frozenset[str]  # Using frozenset for hashable unordered pair
+    value: float
+
+    def __eq__(self, other):
+        if not isinstance(other, TableCell):
+            return False
+        return self.labels == other.labels and abs(self.value - other.value) < 0.05
+
+    def __hash__(self):
+        return hash((self.labels, round(self.value, 3)))  # Round to handle float comparison
+
 @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
 class Accuracy(evaluate.Metric):
     def _info(self):
@@ -71,64 +85,65 @@ class Accuracy(evaluate.Metric):

         return None

-    def _table_to_dict(self,table_str):
-        result_dict = {}
-
+    def _table_to_cell_set(self, table_str: str) -> Set[TableCell]:
+        """Convert markdown table string to a set of TableCell objects."""
+        result_set = set()
+
         table_str = table_str.lstrip("|").rstrip("|")
         parts = table_str.split('||')
         parts = [part for part in parts if "--" not in part]
+
+        if not parts:
+            return result_set
+
         legends = parts[0].split("|")
+        legends = [l.strip() for l in legends if l.strip()]

         rows = len(parts)
-        if rows == 2:
+        if rows == 2:  # Single row table - use single label
             nums = parts[1].split("|")
-            for i in range(len(nums)):
-                result_dict[legends[i]]=float(nums[i])
-        elif rows >=3:
-            for i in range(1,rows):
-                pre_row = parts[i]
-                pre_row = pre_row.split("|")
-                label = pre_row[0]
-                result_dict[label] = {}
-                for j in range(1,len(pre_row)):
-                    result_dict[label][legends[j-1]] = float(pre_row[j])
-        else:
-            return None
+            nums = [n.strip() for n in nums if n.strip()]
+            for i, num in enumerate(nums):
+                try:
+                    value = float(num)
+                    # For single row tables, use a single label
+                    cell = TableCell(frozenset([legends[i]]), value)
+                    result_set.add(cell)
+                except ValueError:
+                    continue
+        elif rows >= 3:  # Multi-row table - use label pairs
+            for i in range(1, rows):
+                row = parts[i].split("|")
+                row = [r.strip() for r in row if r.strip()]
+                if not row:
+                    continue
+
+                row_label = row[0]
+                for j, num in enumerate(row[1:], 1):
+                    if j >= len(legends):
+                        continue
+                    try:
+                        value = float(num)
+                        # For multi-row tables, use label pairs
+                        cell = TableCell(frozenset([row_label, legends[j-1]]), value)
+                        result_set.add(cell)
+                    except ValueError:
+                        continue

-        return result_dict
-    def _markdown_to_dict(self,markdown_str):
+        return result_set
+
+    def _markdown_to_cell_set(self, markdown_str: str) -> Set[TableCell]:
+        """Convert markdown string to a set of TableCell objects."""
         table_str = self._extract_markdown_table(markdown_str)
         if table_str:
-            return self._table_to_dict(table_str)
-        else:
-            return None
-
-    def _calculate_table_metrics(self,pred_table, true_table):
-        true_positives = 0
-        false_positives = 0
-        false_negatives = 0
-
-        for key, pred_value in pred_table.items():
-            if key in true_table:
-                true_value = true_table[key]
-                if isinstance(pred_value, dict) and isinstance(true_value, dict):
-                    nested_metrics = self._calculate_table_metrics(pred_value, true_value)
-                    true_positives += nested_metrics['true_positives']
-                    false_positives += nested_metrics['false_positives']
-                    false_negatives += nested_metrics['false_negatives']
-                elif true_value == 0 and abs(pred_value) < 0.05:
-                    true_positives += 1
-                elif true_value != 0 and abs((pred_value - true_value) / true_value) < 0.05:
-                    true_positives += 1
-                else:
-                    false_positives += 1
-                    false_negatives += 1
-            else:
-                false_positives += 1
-
-        for key in true_table:
-            if key not in pred_table:
-                false_negatives += 1
+            return self._table_to_cell_set(table_str)
+        return set()
+
+    def _calculate_table_metrics(self, pred_cells: Set[TableCell], true_cells: Set[TableCell]) -> Dict[str, Any]:
+        """Calculate metrics using cell set comparison."""
+        true_positives = len(pred_cells.intersection(true_cells))
+        false_positives = len(pred_cells - true_cells)
+        false_negatives = len(true_cells - pred_cells)

         precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0
         recall = true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) > 0 else 0
@@ -146,22 +161,36 @@ class Accuracy(evaluate.Metric):
     def _compute(self, predictions, references):
         predictions = "".join(predictions)
         references = "".join(references)
-        return self._calculate_table_metrics(self._markdown_to_dict(predictions), self._markdown_to_dict(references))
+        pred_cells = self._markdown_to_cell_set(predictions)
+        true_cells = self._markdown_to_cell_set(references)
+        return self._calculate_table_metrics(pred_cells, true_cells)


 def main():
     accuracy_metric = Accuracy()

-    # compute the metric
-    results = accuracy_metric.compute(
+    # Test with different table formats
+    # Test 1: Single row table
+    results1 = accuracy_metric.compute(
+        predictions=["""
+        | | value1 | value2 | value3 ||--|--|--|| data | 1.01 | 2 | 3 |
+        """],
+        references=["""
+        | | value1 | value2 | value3 ||--|--|--|| data | 1 | 2 | 3 |
+        """],
+    )
+    print("Single row table test:", results1)
+
+    # Test 2: Multi-row table (transposed)
+    results2 = accuracy_metric.compute(
         predictions=["""
-        | | lobby | search | band | charge | chain ||--|--|--|--|--|--|| desire | 5 | 8 | 7 | 5 | 9 || wage | 1 | 5 | 3 | 8 | 5 |
-        """],  # predicted table
+        | | desire | wage ||--|--|--|| lobby | 5.01 | 1 || search | 8 | 5 || band | 7 | 3 || charge | 5 | 8 || chain | 9 | 5 |
+        """],
         references=["""
-        | | lobby | search | band | charge | chain ||--|--|--|--|--|--|| desire | 1 | 6 | 7 | 5 | 9 || wage | 1 | 5 | 2 | 8 | 5 |
-        """],  # reference table
+        | | lobby | search | band | charge | chain ||--|--|--|--|--|--|| desire | 5.01 | 8 | 7 | 5 | 9 || wage | 1 | 5 | 3 | 8 | 5 |
+        """],
     )
-    print(results)  # print the results
+    print("Multi-row table test:", results2)

 if __name__ == '__main__':
     main()
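
Aside on the change: the new scoring reduces to plain set arithmetic over cells keyed by an unordered frozenset of labels, which is why the transposed table in Test 2 can still match its reference. The sketch below only illustrates that idea; the helper names make_cell and score_cells are hypothetical and do not exist in metric.py, and it compares rounded values exactly instead of using the 5% tolerance in TableCell.__eq__.

# Illustrative sketch only: make_cell and score_cells are hypothetical helpers, not part of metric.py.
# Cells are keyed by an unordered frozenset of labels and scored with set arithmetic,
# mirroring _calculate_table_metrics; values are rounded and compared exactly here.
from typing import Dict, FrozenSet, Set, Tuple

Cell = Tuple[FrozenSet[str], float]  # (unordered label pair, cell value)

def make_cell(labels, value) -> Cell:
    return (frozenset(labels), round(float(value), 3))

def score_cells(pred: Set[Cell], true: Set[Cell]) -> Dict[str, float]:
    tp = len(pred & true)   # cells present in both tables
    fp = len(pred - true)   # predicted cells absent from the reference
    fn = len(true - pred)   # reference cells missed by the prediction
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
    return {"precision": precision, "recall": recall, "f1": f1}

pred = {make_cell(["desire", "lobby"], 5), make_cell(["wage", "lobby"], 1)}
true = {make_cell(["desire", "lobby"], 5), make_cell(["wage", "lobby"], 2)}
print(score_cells(pred, true))  # {'precision': 0.5, 'recall': 0.5, 'f1': 0.5}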