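"""Span-level evaluation of answer.txt files.

Compares a gold answer.txt against a predicted answer.txt (given as a path or
an open file object) and reports precision, recall and F1 per label type plus
micro- and macro-averaged scores.
"""
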
from collections import Counter
import json


class Tag:
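    """One annotation span parsed from a single tab-separated answer.txt line."""
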
    def __init__(self, txt_line: str):
        # | file_name | label_type | label_start | label_end | label_text |
        try:
            sep = txt_line.strip().split('\t')
            self.file_id = sep[0]
            self.type = sep[1]
            self.start = sep[2]  # kept as str; spans are compared textually in __eq__
            self.end = sep[3]
            self.text = sep[4]
        except IndexError as e:
            raise ValueError(
                f'Malformed answer line (expected 5 tab-separated fields): {txt_line!r}'
            ) from e

    def get_type(self):
        return self.type

    def get_file_id(self):
        return self.file_id
    
    def __eq__(self, other: 'Tag'):
        # Two tags match if file_id, type, start and end are all equal;
        # the surface text is deliberately ignored.
        return (self.file_id == other.file_id
                and self.type == other.type
                and self.start == other.start
                and self.end == other.end)

    def __repr__(self):
        return f'<{self.__class__.__name__} {self.file_id:10} {self.type:10} s:{self.start:5} e:{self.end:5} {self.text}>\n'

    def __hash__(self):
        return hash((self.file_id, self.type, self.start, self.end))
    
class Evaluation_answer_txt:
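    """Compare a gold answer.txt with a predicted one and compute span-level scores."""
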
    def __init__(self, gold_answer, pred_answer):
        self.gold_answer = gold_answer
        self.pred_answer = pred_answer

        self.gold_set = set() # set of Tag
        self.pred_set = set() # set of Tag

        self.type_set = set() # set of label type str
        self.gold_label_counter = Counter() # Counter of gold label type

        self.resault_score = {}

    def _lines_to_tag_set(self, lines, set_type):  # set_type: 'gold' or 'pred'
        tags = []
        for i, line in enumerate(lines, start=1):
            try:
                tags.append(Tag(line))
            except ValueError:
                print(f'Error at {set_type} answer line: {i}, {line}')
        return set(tags)
    
    
    def _set_filter(self, tag_set, label_type):
        # keep only the tags whose label type matches
        return {tag for tag in tag_set if tag.get_type() == label_type}
    
    
    def _division(self, a, b):
        # safe division: a zero denominator yields 0.0 instead of raising
        try:
            return a / b
        except ZeroDivisionError:
            return 0.0

    def _f1_score(self, TP=None, FP=None, FN=None):
        if TP is None or FP is None or FN is None:
            raise ValueError('TP, FP, FN should be given.')

        precision = self._division(TP, TP + FP)
        recall = self._division(TP, TP + FN)
        f1 = self._division(2 * precision * recall, precision + recall)

        return {'precision': precision, 'recall': recall, 'f1': f1}

    def eval(self, ignore_no_gold_tag_file=True):
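        """Compute per-label and micro/macro averaged precision/recall/F1.

        Returns a dict keyed by label type plus 'MICRO_AVERAGE' and
        'MACRO_AVERAGE'; each value holds precision, recall, f1 and support.
        """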
        with open(self.gold_answer, 'r') as f:
            gold_line = f.readlines()

        # pred_answer may be a path or an already-open file object (e.g. an upload)
        if isinstance(self.pred_answer, str):
            with open(self.pred_answer, 'r') as f:
                pred_line = f.readlines()
        else:
            pred_line = self.pred_answer.readlines()
            # binary file objects yield bytes; decode those, pass str lines through
            pred_line = [line.decode('utf-8') if isinstance(line, bytes) else line
                         for line in pred_line]

        self.gold_set = self._lines_to_tag_set(gold_line, 'gold')
        self.pred_set = self._lines_to_tag_set(pred_line, 'pred')

        # The official islab AICUP scorer ignores files that have no gold tags:
        # only files listed in the gold answer.txt are scored.
        if ignore_no_gold_tag_file:
            # filter the files that have no gold tags
            gold_files = {tag.get_file_id() for tag in self.gold_set}
            self.pred_set = {tag for tag in self.pred_set if tag.get_file_id() in gold_files}

        # collect the label types present and count gold tags per type
        for tag in self.gold_set:
            self.type_set.add(tag.get_type())
            self.gold_label_counter[tag.get_type()] += 1
        for tag in self.pred_set:
            self.type_set.add(tag.get_type())

        TP_set = self.gold_set & self.pred_set
        FP_set = self.pred_set - self.gold_set
        FN_set = self.gold_set - self.pred_set

        # compute precision/recall/F1 for each label type
        for label in self.type_set:
            filter_TP = self._set_filter(TP_set, label)
            filter_FP = self._set_filter(FP_set, label)
            filter_FN = self._set_filter(FN_set, label)
            score = self._f1_score(len(filter_TP), len(filter_FP), len(filter_FN))
            self.resault_score[label] = score
       
        # MICRO_AVERAGE
        self.resault_score['MICRO_AVERAGE'] = self._f1_score(len(TP_set), len(FP_set), len(FN_set))

        # MACRO_AVERAGE
        precision_sum = 0
        recall_sum = 0
        for label in self.type_set:
            precision_sum += self.resault_score[label]['precision']
            recall_sum += self.resault_score[label]['recall']

        precision = self._division(precision_sum, len(self.type_set))
        recall = self._division(recall_sum, len(self.type_set))
        # As in the AICUP scorer, macro-F1 is derived from the macro-averaged
        # precision and recall rather than from averaging per-label F1 scores.
        f1 = self._division(2 * precision * recall, precision + recall)

        self.resault_score['MACRO_AVERAGE'] = {'precision': precision, 'recall': recall, 'f1': f1}

        # attach support (the number of gold tags) to each score entry
        for label in self.type_set:
            self.resault_score[label]['support'] = self.gold_label_counter[label]
        self.resault_score['MICRO_AVERAGE']['support'] = len(self.gold_set)
        self.resault_score['MACRO_AVERAGE']['support'] = len(self.gold_set)

        # return json.dumps(self.resault_score, indent=4)
        return self.resault_score


if __name__ == "__main__":
    gold_path = 'dataset/Setting3_test_answer.txt'
    pred_path = '.output/[meta-llama@Llama-2-7b-hf][Setting3][icl]answer.txt'

    evaluator = Evaluation_answer_txt(gold_path, pred_path)
    res = evaluator.eval()
    print(res)
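
    # For a more readable view, the same result can be dumped with the json
    # module imported above (mirroring the commented-out json.dumps call in eval()).
    print(json.dumps(res, indent=4))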