import re
import evaluate
import datasets

_DESCRIPTION = """
Table evaluation metrics for assessing how closely a predicted table matches a reference table. The following cell-level metrics are computed:

1. Precision: the ratio of correctly predicted cells to the total number of cells in the predicted table
2. Recall: the ratio of correctly predicted cells to the total number of cells in the reference table
3. F1 score: the harmonic mean of precision and recall

These metrics help evaluate the accuracy of table data extraction or generation.
"""


_KWARGS_DESCRIPTION = """
Args:
    predictions (`list` of `str`): Predicted tables in Markdown format.
    references (`list` of `str`): Reference tables in Markdown format.

Returns:
    dict: A dictionary containing the following metrics:
        - precision (`float`): Precision score, range [0, 1]
        - recall (`float`): Recall score, range [0, 1]
        - f1 (`float`): F1 score, range [0, 1]
        - true_positives (`int`): Number of correctly predicted cells
        - false_positives (`int`): Number of incorrectly predicted cells
        - false_negatives (`int`): Number of reference cells that were not correctly predicted

Examples:
    >>> accuracy_metric = evaluate.load("accuracy")
    >>> results = accuracy_metric.compute(
    ...     predictions=["|  | lobby | search | band | charge | chain ||--|--|--|--|--|--|| desire | 5 | 8 | 7 | 5 | 9 || wage | 1 | 5 | 3 | 8 | 5 |"],
    ...     references=["|  | lobby | search | band | charge | chain ||--|--|--|--|--|--|| desire | 1 | 6 | 7 | 5 | 9 || wage | 1 | 5 | 2 | 8 | 5 |"]
    ... )
    >>> print(results)
    {'precision': 0.7, 'recall': 0.7, 'f1': 0.7, 'true_positives': 7, 'false_positives': 3, 'false_negatives': 3}
"""

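# Worked arithmetic for the docstring example above: both tables share all row and
# column keys, and 7 of the 10 data cells agree while 3 differ, so
# true_positives = 7, false_positives = 3 and false_negatives = 3, which gives
# precision = 7 / (7 + 3) = 0.7, recall = 7 / (7 + 3) = 0.7 and
# f1 = 2 * 0.7 * 0.7 / (0.7 + 0.7) = 0.7.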

_CITATION = """
@article{scikit-learn,
  title={Scikit-learn: Machine Learning in {P}ython},
  author={Pedregosa, F. and Varoquaux, G. and Gramfort, A. and Michel, V.
         and Thirion, B. and Grisel, O. and Blondel, M. and Prettenhofer, P.
         and Weiss, R. and Dubourg, V. and Vanderplas, J. and Passos, A. and
         Cournapeau, D. and Brucher, M. and Perrot, M. and Duchesnay, E.},
  journal={Journal of Machine Learning Research},
  volume={12},
  pages={2825--2830},
  year={2011}
}
"""


@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class Accuracy(evaluate.Metric):
    def _info(self):
        return evaluate.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(
                {
                    "predictions": datasets.Value("string"),
                    "references": datasets.Value("string"),
                }
            ),
            reference_urls=["https://scikit-learn.org/stable/modules/generated/sklearn.metrics.accuracy_score.html"],
        )

    def _extract_markdown_table(self, text):
        # Strip newlines and spaces so the table collapses onto a single line,
        # then pull out every pipe-delimited row.
        text = text.replace('\n', '')
        text = text.replace(' ', '')
        pattern = r'\|(?:[^|]+\|)+[^|]+\|'
        matches = re.findall(pattern, text)

        if matches:
            return ''.join(matches)

        return None
    
    def _table_to_dict(self, table_str):
        # Convert the flattened Markdown table into a dict. A single data row maps
        # column header -> value; multiple data rows map row label -> {header -> value}.
        result_dict = {}

        table_str = table_str.lstrip('|').rstrip('|')
        parts = table_str.split('||')
        # Drop the Markdown separator row (the one made of "--" cells).
        parts = [part for part in parts if '--' not in part]
        legends = parts[0].split('|')

        rows = len(parts)
        if rows == 2:
            # Header row plus one data row: map each column header to its value.
            nums = parts[1].split('|')
            for i in range(len(nums)):
                result_dict[legends[i]] = float(nums[i])
        elif rows >= 3:
            # Header row plus several data rows: the first cell of each data row is
            # the row label, the remaining cells align with the column headers.
            for i in range(1, rows):
                pre_row = parts[i].split('|')
                label = pre_row[0]
                result_dict[label] = {}
                for j in range(1, len(pre_row)):
                    result_dict[label][legends[j - 1]] = float(pre_row[j])
        else:
            return None

        return result_dict

    def _markdown_to_dict(self, markdown_str):
        # Extract the table portion of the Markdown string and parse it into a dict.
        table_str = self._extract_markdown_table(markdown_str)
        if table_str:
            return self._table_to_dict(table_str)
        return None
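
    # Illustration (not used by the metric itself): for the sample table in the
    # docstring example, _markdown_to_dict returns a nested dict keyed by row label
    # and then by column header, with every cell parsed as a float, roughly:
    #   {'desire': {'lobby': 5.0, 'search': 8.0, 'band': 7.0, 'charge': 5.0, 'chain': 9.0},
    #    'wage':   {'lobby': 1.0, 'search': 5.0, 'band': 3.0, 'charge': 8.0, 'chain': 5.0}}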
    
    def _calculate_table_metrics(self, pred_table, true_table):
        # Count cell-level matches between the two (possibly nested) table dicts.
        true_positives = 0
        false_positives = 0
        false_negatives = 0

        for key, pred_value in pred_table.items():
            if key in true_table:
                true_value = true_table[key]
                if isinstance(pred_value, dict) and isinstance(true_value, dict):
                    # Both values are rows: recurse and accumulate the nested counts.
                    nested_metrics = self._calculate_table_metrics(pred_value, true_value)
                    true_positives += nested_metrics['true_positives']
                    false_positives += nested_metrics['false_positives']
                    false_negatives += nested_metrics['false_negatives']
                elif pred_value == true_value:
                    true_positives += 1
                else:
                    # The cell exists in both tables but the values disagree.
                    false_positives += 1
                    false_negatives += 1
            else:
                # Predicted key with no counterpart in the reference table; counted as
                # a single false positive even when the value is a whole nested row.
                false_positives += 1

        for key in true_table:
            if key not in pred_table:
                # Reference key that the prediction missed entirely.
                false_negatives += 1

        precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0
        recall = true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) > 0 else 0
        f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

        return {
            'precision': precision,
            'recall': recall,
            'f1': f1,
            'true_positives': true_positives,
            'false_positives': false_positives,
            'false_negatives': false_negatives
        }

    def _compute(self, predictions, references):
        # evaluate passes predictions/references as lists of strings; concatenate each
        # list into a single Markdown string before parsing and comparing.
        predictions = ''.join(predictions)
        references = ''.join(references)
        return self._calculate_table_metrics(self._markdown_to_dict(predictions), self._markdown_to_dict(references))
    

def main():
    # Quick self-test: compare two small Markdown tables that differ in three cells.
    accuracy_metric = Accuracy()

    results = accuracy_metric.compute(
        predictions=["""
|  | lobby | search | band | charge | chain ||--|--|--|--|--|--|| desire | 5 | 8 | 7 | 5 | 9 || wage | 1 | 5 | 3 | 8 | 5 |
"""],
        references=["""
|  | lobby | search | band | charge | chain ||--|--|--|--|--|--|| desire | 1 | 6 | 7 | 5 | 9 || wage | 1 | 5 | 2 | 8 | 5 |
"""],
    )
    print(results)


if __name__ == '__main__':
    main()
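
# Usage sketch (the filename below is an assumption, not defined by this module):
# the metric can also be loaded through the evaluate library by pointing
# evaluate.load at a local copy of this script, for example:
#
#     import evaluate
#     table_metric = evaluate.load("./table_accuracy.py")
#     table_metric.compute(predictions=["| a | b ||--|--|| 1 | 2 |"],
#                          references=["| a | b ||--|--|| 1 | 2 |"])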