Include weighted mode. ORIGINAL FAIREVAL SCRIPT IS MODIFIED
Files changed:
- FairEval.py (+23, -7)
- FairEvalUtils.py (+2, -1)
FairEval.py

@@ -119,6 +119,7 @@ class FairEvaluation(evaluate.Metric):
         suffix: bool = False,
         scheme: Optional[str] = None,
         mode: Optional[str] = 'fair',
+        weights: dict = None,
         error_format: Optional[str] = 'count',
         zero_division: Union[str, int] = "warn",
     ):
@@ -147,25 +148,38 @@ class FairEvaluation(evaluate.Metric):
         pred_spans = seq_to_fair(pred_spans)
 
         # (3) COUNT ERRORS AND CALCULATE SCORES
-        total_errors = compare_spans([], [])
-
+        total_errors = compare_spans([], [])
         for i in range(len(true_spans)):
             sentence_errors = compare_spans(true_spans[i], pred_spans[i])
             total_errors = add_dict(total_errors, sentence_errors)
 
-        results = calculate_results(total_errors)
+        if weights is None and mode == 'weighted':
+            print("The chosen mode is \'weighted\', but no weights are given. Setting weights to:\n")
+            weights = {"TP": {"TP": 1},
+                       "FP": {"FP": 1},
+                       "FN": {"FN": 1},
+                       "LE": {"TP": 0, "FP": 0.5, "FN": 0.5},
+                       "BE": {"TP": 0.5, "FP": 0.25, "FN": 0.25},
+                       "LBE": {"TP": 0, "FP": 0.5, "FN": 0.5}}
+            print(weights)
+
+        config = {"labels": "all", "eval_method": [mode], "weights": weights,}
+        results = calculate_results(total_errors, config)
         del results['conf']
 
-        # (4) SELECT OUTPUT MODE AND REFORMAT AS SEQEVAL
+        # (4) SELECT OUTPUT MODE AND REFORMAT AS SEQEVAL-HUGGINGFACE OUTPUT
+        # initialize empty dictionary and count errors
         output = {}
         total_trad_errors = results['overall']['traditional']['FP'] + results['overall']['traditional']['FN']
         total_fair_errors = results['overall']['fair']['FP'] + results['overall']['fair']['FN'] + \
                             results['overall']['fair']['LE'] + results['overall']['fair']['BE'] + \
                             results['overall']['fair']['LBE']
 
-        assert mode in ['traditional', 'fair'], 'mode must be \'traditional\' or \'fair\''
+        # assert valid options
+        assert mode in ['traditional', 'fair', 'weighted'], 'mode must be \'traditional\', \'fair\' or \'weighted\''
         assert error_format in ['count', 'proportion'], 'error_format must be \'count\' or \'proportion\''
 
+        # append entity-level errors and scores
         if mode == 'traditional':
             for k, v in results['per_label'][mode].items():
                 if error_format == 'count':
@@ -174,7 +188,7 @@ class FairEvaluation(evaluate.Metric):
             elif error_format == 'proportion':
                 output[k] = {'precision': v['Prec'], 'recall': v['Rec'], 'f1': v['F1'], 'TP': v['TP'],
                              'FP': v['FP'] / total_trad_errors, 'FN': v['FN'] / total_trad_errors}
-        elif mode == 'fair':
+        elif mode == 'fair' or mode == 'weighted':
             for k, v in results['per_label'][mode].items():
                 if error_format == 'count':
                     output[k] = {'precision': v['Prec'], 'recall': v['Rec'], 'f1': v['F1'], 'TP': v['TP'],
@@ -185,10 +199,12 @@ class FairEvaluation(evaluate.Metric):
                                  'LE': v['LE'] / total_fair_errors, 'BE': v['BE'] / total_fair_errors,
                                  'LBE': v['LBE'] / total_fair_errors}
 
+        # append overall scores
         output['overall_precision'] = results['overall'][mode]['Prec']
         output['overall_recall'] = results['overall'][mode]['Rec']
         output['overall_f1'] = results['overall'][mode]['F1']
 
+        # append overall error counts
         if mode == 'traditional':
             output['TP'] = results['overall'][mode]['TP']
             output['FP'] = results['overall'][mode]['FP']
@@ -196,7 +212,7 @@ class FairEvaluation(evaluate.Metric):
             if error_format == 'proportion':
                 output['FP'] = output['FP'] / total_trad_errors
                 output['FN'] = output['FN'] / total_trad_errors
-        elif mode == 'fair':
+        elif mode == 'fair' or mode == 'weighted':
             output['TP'] = results['overall'][mode]['TP']
             output['FP'] = results['overall'][mode]['FP']
             output['FN'] = results['overall'][mode]['FN']
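For reference, a minimal sketch of how the new weighted mode might be called once this change is in place. The repository id passed to evaluate.load and the toy token sequences are placeholders, not part of this commit, and the compute call is assumed to follow the usual seqeval-style convention of predictions/references as lists of IOB-tagged sentences. The mode, weights, and error_format parameters are the ones added or extended above; the weights table mirrors the default introduced in this commit, where each error type (TP, FP, FN, label error LE, boundary error BE, label-and-boundary error LBE) maps to the partial credit it contributes to TP, FP, and FN.

```python
import evaluate

# Placeholder repo id -- substitute the actual id of this metric repository.
fair_eval = evaluate.load("your-namespace/FairEval")

# Toy IOB2-tagged sentences (placeholder data, one sentence each).
references  = [["O", "B-PER", "I-PER", "O", "B-LOC"]]
predictions = [["O", "B-PER", "O",     "O", "B-ORG"]]

# Custom weights: each error type distributes partial credit over TP/FP/FN.
# This mirrors the default dictionary defined in the diff above.
weights = {"TP":  {"TP": 1},
           "FP":  {"FP": 1},
           "FN":  {"FN": 1},
           "LE":  {"TP": 0,   "FP": 0.5,  "FN": 0.5},
           "BE":  {"TP": 0.5, "FP": 0.25, "FN": 0.25},
           "LBE": {"TP": 0,   "FP": 0.5,  "FN": 0.5}}

results = fair_eval.compute(predictions=predictions,
                            references=references,
                            mode='weighted',       # new mode enabled by this commit
                            weights=weights,       # new parameter added by this commit
                            error_format='count')

# Output keys follow the seqeval-style layout built in the diff above.
print(results['overall_precision'], results['overall_recall'], results['overall_f1'])
print(results['TP'], results['FP'], results['FN'])
```

If weights is omitted while mode='weighted', the metric prints and falls back to the default weight table shown in the diff.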
FairEvalUtils.py

@@ -1149,7 +1149,7 @@ def add_dict(base_dict, dict_to_add):
 
 #############################
 
-def calculate_results(eval_dict, **config):
+def calculate_results(eval_dict, config):
     """
     Calculate overall precision, recall, and F-scores.
 
@@ -1173,6 +1173,7 @@ def calculate_results(eval_dict, **config):
     eval_dict["overall"]["weighted"] = {}
     for err_type in eval_dict["overall"]["fair"]:
         eval_dict["overall"]["weighted"][err_type] = eval_dict["overall"]["fair"][err_type]
+    eval_dict["per_label"]["weighted"] = {}
     for label in eval_dict["per_label"]["fair"]:
         eval_dict["per_label"]["weighted"][label] = {}
         for err_type in eval_dict["per_label"]["fair"][label]:
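A note on the FairEvalUtils.py change: calculate_results now receives the configuration as a single positional dictionary instead of **config keyword arguments, so any caller outside FairEval.py has to be adapted. Below is a minimal sketch of the updated calling convention; the import line and the empty error dictionary are illustrative assumptions only, and the config keys are the ones FairEval.py builds in the diff above.

```python
from FairEvalUtils import compare_spans, calculate_results  # assumed local import

# In FairEval.py the error dictionary is accumulated sentence by sentence with
# add_dict(); here it is only initialised to illustrate the call shape.
total_errors = compare_spans([], [])

config = {"labels": "all",          # evaluate all labels
          "eval_method": ["fair"],  # or ["traditional"] / ["weighted"]
          "weights": None}          # nested weight table when using "weighted"

# Before this commit the signature was calculate_results(eval_dict, **config);
# the config dictionary is now passed as one positional argument.
results = calculate_results(total_errors, config)
```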