from collections import Counter


class Tag:
    def __init__(self, txt_line: str):
        # Expected TSV format per line:
        # file_name \t label_type \t label_start \t label_end \t label_text
        try:
            sep = txt_line.strip().split('\t')
            self.file_id = sep[0]
            self.type = sep[1]
            self.start = sep[2]  # kept as str; exact string equality is enough here
            self.end = sep[3]
            self.text = sep[4]
        except IndexError as e:
            raise ValueError(f'Malformed answer line (expected 5 tab-separated fields): {txt_line!r}') from e
    def get_type(self):
        return self.type

    def get_file_id(self):
        return self.file_id
    def __eq__(self, other: 'Tag'):
        # Two tags are equal when file_id, type, start, and end all match;
        # the label text is deliberately excluded from the comparison.
        return (self.file_id == other.file_id
                and self.type == other.type
                and self.start == other.start
                and self.end == other.end)
    def __repr__(self):
        # trailing newline so a printed set shows one tag per line
        return f'<{self.__class__.__name__} {self.file_id:10} {self.type:10} s:{self.start:5} e:{self.end:5} {self.text}>\n'

    def __hash__(self):
        # hash on the same fields as __eq__ so tags deduplicate correctly in sets
        return hash((self.file_id, self.type, self.start, self.end))
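
# A minimal usage sketch; the file id, offsets, and text below are made-up
# illustration values, not taken from any real dataset.
#
#   tag = Tag('file01\tNAME\t10\t18\tJohn Doe')
#   tag.get_type()      # 'NAME'
#   tag.get_file_id()   # 'file01'
#   Tag('file01\tNAME\t10\t18\tJ. Doe') == tag  # True: text is ignored
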
class Evaluation_answer_txt:
    def __init__(self, gold_answer, pred_answer):
        self.gold_answer = gold_answer  # path to the gold answer.txt
        self.pred_answer = pred_answer  # path, or an already-open file object
        self.gold_set = set()   # set of Tag
        self.pred_set = set()   # set of Tag
        self.type_set = set()   # set of label-type strings
        self.gold_label_counter = Counter()  # gold tag count per label type
        self.result_score = {}
    def _lines_to_tag_set(self, lines, set_type):  # set_type: 'gold' or 'pred'
        tags = []
        for i, line in enumerate(lines):
            try:
                tags.append(Tag(line))
            except ValueError:
                print(f'Error at {set_type} answer line {i + 1}: {line}')
        return set(tags)
    def _set_filter(self, tag_set, label_type):
        # keep only the tags of the given label type
        return {tag for tag in tag_set if tag.get_type() == label_type}
    def _division(self, a, b):
        # define x / 0 as 0.0 so empty classes do not crash the scoring
        try:
            return a / b
        except ZeroDivisionError:
            return 0.0
    def _f1_score(self, TP=None, FP=None, FN=None):
        if TP is None or FP is None or FN is None:
            raise ValueError('TP, FP, and FN must all be given.')
        precision = self._division(TP, TP + FP)
        recall = self._division(TP, TP + FN)
        f1 = self._division(2 * precision * recall, precision + recall)
        return {'precision': precision, 'recall': recall, 'f1': f1}
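
    # Worked example with made-up counts: TP=8, FP=2, FN=4 gives
    # precision = 8 / 10 = 0.8, recall = 8 / 12 ≈ 0.6667, and
    # f1 = 2 * 0.8 * 0.6667 / (0.8 + 0.6667) ≈ 0.7273.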
    def eval(self, ignore_no_gold_tag_file=True):
        with open(self.gold_answer, 'r', encoding='utf-8') as f:
            gold_line = f.readlines()
        # the predicted answer may be given as a path or as a file object
        if isinstance(self.pred_answer, str):
            with open(self.pred_answer, 'r', encoding='utf-8') as f:
                pred_line = f.readlines()
        else:
            pred_line = self.pred_answer.readlines()
            # a binary file object yields bytes lines, which need decoding
            pred_line = [line.decode('utf-8') if isinstance(line, bytes) else line
                         for line in pred_line]
        self.gold_set = self._lines_to_tag_set(gold_line, 'gold')
        self.pred_set = self._lines_to_tag_set(pred_line, 'pred')
        # The ISLab AI CUP scorer only considers files that appear in the gold
        # answer.txt, so predictions for files without any gold tag are dropped.
        if ignore_no_gold_tag_file:
            gold_files = {tag.get_file_id() for tag in self.gold_set}
            self.pred_set = {tag for tag in self.pred_set if tag.get_file_id() in gold_files}
        # collect the label types and the per-type gold counts
        for tag in self.gold_set:
            self.type_set.add(tag.get_type())
            self.gold_label_counter[tag.get_type()] += 1
        for tag in self.pred_set:
            self.type_set.add(tag.get_type())
        TP_set = self.gold_set & self.pred_set
        FP_set = self.pred_set - self.gold_set
        FN_set = self.gold_set - self.pred_set
        # score each label type separately
        for label in self.type_set:
            filter_TP = self._set_filter(TP_set, label)
            filter_FP = self._set_filter(FP_set, label)
            filter_FN = self._set_filter(FN_set, label)
            self.result_score[label] = self._f1_score(len(filter_TP), len(filter_FP), len(filter_FN))
        # MICRO_AVERAGE: pool the TP/FP/FN counts across all label types
        self.result_score['MICRO_AVERAGE'] = self._f1_score(len(TP_set), len(FP_set), len(FN_set))
        # MACRO_AVERAGE: average precision and recall over the label types,
        # then take their harmonic mean (the AI CUP convention), rather than
        # averaging the per-label F1 scores directly.
        precision_sum = 0
        recall_sum = 0
        for label in self.type_set:
            precision_sum += self.result_score[label]['precision']
            recall_sum += self.result_score[label]['recall']
        precision = self._division(precision_sum, len(self.type_set))
        recall = self._division(recall_sum, len(self.type_set))
        f1 = self._division(2 * precision * recall, precision + recall)
        self.result_score['MACRO_AVERAGE'] = {'precision': precision, 'recall': recall, 'f1': f1}
        # attach support (the gold tag count) to every entry
        for label in self.type_set:
            self.result_score[label]['support'] = self.gold_label_counter[label]
        self.result_score['MICRO_AVERAGE']['support'] = len(self.gold_set)
        self.result_score['MACRO_AVERAGE']['support'] = len(self.gold_set)
        return self.result_score
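
# The returned dict maps each label type plus 'MICRO_AVERAGE' and
# 'MACRO_AVERAGE' to {'precision', 'recall', 'f1', 'support'}; for example
# (made-up numbers):
#   {'NAME': {'precision': 0.8, 'recall': 0.67, 'f1': 0.73, 'support': 12}, ...}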

if __name__ == "__main__":
    gold_path = 'dataset/Setting3_test_answer.txt'
    pred_path = '.output/[meta-llama@Llama-2-7b-hf][Setting3][icl]answer.txt'
    evaluator = Evaluation_answer_txt(gold_path, pred_path)  # avoid shadowing the built-in eval
    res = evaluator.eval()
    print(res)
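
# A minimal sketch of the file-object input path, assuming the same TSV
# answer format; the in-memory bytes below are made-up illustration data.
#
#   import io, json
#   pred_bytes = io.BytesIO(b'file01\tNAME\t10\t18\tJohn Doe\n')
#   scores = Evaluation_answer_txt(gold_path, pred_bytes).eval()
#   print(json.dumps(scores, indent=4))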