First try entity_ratio option
FairEval.py  +41 -45
@@ -147,21 +147,24 @@ class FairEvaluation(evaluate.Metric):
         true_spans = seq_to_fair(true_spans)
         pred_spans = seq_to_fair(pred_spans)

-        # (3) COUNT ERRORS AND CALCULATE SCORES
+        # (3) COUNT ERRORS AND CALCULATE SCORES (counting total ground truth entities too)
         total_errors = compare_spans([], [])
+        total_ref_entities = 0
         for i in range(len(true_spans)):
+            total_ref_entities += len(true_spans[i])
             sentence_errors = compare_spans(true_spans[i], pred_spans[i])
             total_errors = add_dict(total_errors, sentence_errors)

         if weights is None and mode == 'weighted':
-            print("The chosen mode is \'weighted\', but no weights are given. Setting weights to:\n")
             weights = {"TP": {"TP": 1},
- […5 removed lines not legible in this view…]
-            print(weights)
+                       "FP": {"FP": 1},
+                       "FN": {"FN": 1},
+                       "LE": {"TP": 0, "FP": 0.5, "FN": 0.5},
+                       "BE": {"TP": 0.5, "FP": 0.25, "FN": 0.25},
+                       "LBE": {"TP": 0, "FP": 0.5, "FN": 0.5}}
+            print("The chosen mode is \'weighted\', but no weights are given. Setting weights to:\n")
+            for k in weights:
+                print(k, ":", weights[k])

         config = {"labels": "all", "eval_method": [mode], "weights": weights,}
         results = calculate_results(total_errors, config)
@@ -170,34 +173,36 @@ class FairEvaluation(evaluate.Metric):
         # (4) SELECT OUTPUT MODE AND REFORMAT AS SEQEVAL-HUGGINGFACE OUTPUT
         # initialize empty dictionary and count errors
         output = {}
- […4 removed lines not legible in this view…]
+        # control the divider for the error_format (count, proportion over total errors or over total entities)
+        if error_format == 'count':
+            trad_divider = 1
+            fair_divider = 1
+        elif error_format == 'entity_ratio':
+            trad_divider = total_ref_entities
+            fair_divider = total_ref_entities
+        elif error_format == 'error_ratio':
+            trad_divider = results['overall']['traditional']['FP'] + results['overall']['traditional']['FN']
+            fair_divider = results['overall']['fair']['FP'] + results['overall']['fair']['FN'] + \
+                           results['overall']['fair']['LE'] + results['overall']['fair']['BE'] + \
+                           results['overall']['fair']['LBE']
+

         # assert valid options
         assert mode in ['traditional', 'fair', 'weighted'], 'mode must be \'traditional\', \'fair\' or \'weighted\''
-        assert error_format in ['count', 'proportion'], …
+        assert error_format in ['count', 'error_ratio', 'entity_ratio'], 'error_format must be \'count\', \'error_ratio\' or \'entity_ratio\''

         # append entity-level errors and scores
         if mode == 'traditional':
             for k, v in results['per_label'][mode].items():
- […3 removed lines not legible in this view…]
-                elif error_format == 'proportion':
-                    output[k] = {'precision': v['Prec'], 'recall': v['Rec'], 'f1': v['F1'], 'TP': v['TP'],
-                                 'FP': v['FP'] / total_trad_errors, 'FN': v['FN'] / total_trad_errors}
+                output[k] = {'precision': v['Prec'], 'recall': v['Rec'], 'f1': v['F1'],
+                             'TP': v['TP'] / trad_divider if error_format == 'entity_ratio' else v['TP'],
+                             'FP': v['FP'] / trad_divider, 'FN': v['FN'] / trad_divider}
         elif mode == 'fair' or mode == 'weighted':
             for k, v in results['per_label'][mode].items():
- […4 removed lines not legible in this view…]
-                    output[k] = {'precision': v['Prec'], 'recall': v['Rec'], 'f1': v['F1'], 'TP': v['TP'],
-                                 'FP': v['FP'] / total_fair_errors, 'FN': v['FN'] / total_fair_errors,
-                                 'LE': v['LE'] / total_fair_errors, 'BE': v['BE'] / total_fair_errors,
-                                 'LBE': v['LBE'] / total_fair_errors}
+                output[k] = {'precision': v['Prec'], 'recall': v['Rec'], 'f1': v['F1'],
+                             'TP': v['TP'] / fair_divider if error_format == 'entity_ratio' else v['TP'],
+                             'FP': v['FP'] / fair_divider, 'FN': v['FN'] / fair_divider,
+                             'LE': v['LE'] / fair_divider, 'BE': v['BE'] / fair_divider, 'LBE': v['LBE'] / fair_divider}

         # append overall scores
         output['overall_precision'] = results['overall'][mode]['Prec']
@@ -206,25 +211,16 @@ class FairEvaluation(evaluate.Metric):

         # append overall error counts
         if mode == 'traditional':
-            output['TP'] = results['overall'][mode]['TP']
-            output['FP'] = results['overall'][mode]['FP']
-            output['FN'] = results['overall'][mode]['FN']
-            if error_format == 'proportion':
-                output['FP'] = output['FP'] / total_trad_errors
-                output['FN'] = output['FN'] / total_trad_errors
+            output['TP'] = results['overall'][mode]['TP'] / trad_divider if error_format == 'entity_ratio' else results['overall'][mode]['TP']
+            output['FP'] = results['overall'][mode]['FP'] / trad_divider
+            output['FN'] = results['overall'][mode]['FN'] / trad_divider
         elif mode == 'fair' or 'weighted':
-            output['TP'] = results['overall'][mode]['TP']
-            output['FP'] = results['overall'][mode]['FP']
-            output['FN'] = results['overall'][mode]['FN']
-            output['LE'] = results['overall'][mode]['LE']
-            output['BE'] = results['overall'][mode]['BE']
-            output['LBE'] = results['overall'][mode]['LBE']
-            if error_format == 'proportion':
-                output['FP'] = output['FP'] / total_fair_errors
-                output['FN'] = output['FN'] / total_fair_errors
-                output['LE'] = output['LE'] / total_fair_errors
-                output['BE'] = output['BE'] / total_fair_errors
-                output['LBE'] = output['LBE'] / total_fair_errors
+            output['TP'] = results['overall'][mode]['TP'] / fair_divider if error_format == 'entity_ratio' else results['overall'][mode]['TP']
+            output['FP'] = results['overall'][mode]['FP'] / fair_divider
+            output['FN'] = results['overall'][mode]['FN'] / fair_divider
+            output['LE'] = results['overall'][mode]['LE'] / fair_divider
+            output['BE'] = results['overall'][mode]['BE'] / fair_divider
+            output['LBE'] = results['overall'][mode]['LBE'] / fair_divider

         return output
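The heart of the change is the divider selection right after output = {}: instead of the old boolean-style 'proportion' option, the raw counts are now divided either by the total number of errors ('error_ratio') or by the total number of gold entities ('entity_ratio'). The standalone sketch below is not part of FairEval.py; it only restates that divider logic on hypothetical fair-mode counts so the three error_format options can be compared side by side (the normalize helper and all numbers are made up for illustration).

# Standalone illustration (not part of FairEval.py): how the new error_format
# options rescale fair-mode counts.  The helper and all numbers are hypothetical.

def normalize(counts, total_ref_entities, error_format="count"):
    """Mirror of the divider logic added in this commit (sketch only)."""
    if error_format == "count":
        divider = 1                       # raw counts, numerically unchanged
    elif error_format == "entity_ratio":
        divider = total_ref_entities      # scale by the number of gold entities
    elif error_format == "error_ratio":
        # scale by the total number of errors (FP + FN + LE + BE + LBE)
        divider = sum(v for k, v in counts.items() if k != "TP")
    else:
        raise ValueError("error_format must be 'count', 'error_ratio' or 'entity_ratio'")

    out = dict(counts)
    for key in ("FP", "FN", "LE", "BE", "LBE"):
        out[key] = counts[key] / divider
    if error_format == "entity_ratio":    # as in the diff, TP is only rescaled here
        out["TP"] = counts["TP"] / divider
    return out

fair_counts = {"TP": 80, "FP": 6, "FN": 8, "LE": 3, "BE": 2, "LBE": 1}   # hypothetical
print(normalize(fair_counts, total_ref_entities=100, error_format="count"))
# raw counts: FP 6.0, FN 8.0, LE 3.0, BE 2.0, LBE 1.0
print(normalize(fair_counts, total_ref_entities=100, error_format="error_ratio"))
# fractions of the 20 errors: FP 0.3, FN 0.4, LE 0.15, BE 0.1, LBE 0.05 (TP stays 80)
print(normalize(fair_counts, total_ref_entities=100, error_format="entity_ratio"))
# per gold entity: TP 0.8, FP 0.06, FN 0.08, LE 0.03, BE 0.02, LBE 0.01

Note that in the commit itself the same trad_divider / fair_divider is applied to both the per-label and the overall counts, and TP is only rescaled for 'entity_ratio', so 'error_ratio' reports TP as a raw count next to the error fractions.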
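For completeness, this is roughly how the new option would be selected by a user of the metric. It is a hedged sketch: the repo id, the IOB-tag input format, and the predictions/references keyword names are assumptions based on how seqeval-style metrics on the Hub are usually loaded; only mode, weights, and error_format are confirmed by the code in this commit.

# Hypothetical usage sketch -- repo id and input format are assumptions,
# not confirmed by this commit.
import evaluate

faireval = evaluate.load("hpi-dhc/FairEval")  # assumed repo id of this space

predictions = [["O", "B-PER", "I-PER", "O", "B-LOC"]]
references  = [["O", "B-PER", "I-PER", "O", "B-ORG"]]  # B-LOC vs B-ORG -> a labeling error (LE) in fair mode

# error_format='entity_ratio' is the option introduced in this commit:
# FP/FN/LE/BE/LBE (and TP) are reported as fractions of the gold entities.
results = faireval.compute(predictions=predictions,
                           references=references,
                           mode="fair",
                           error_format="entity_ratio")
print(results["overall_precision"], results["LE"])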