Avoid dict overrides for entity-level
Browse files- FairEval.py +38 -18
- HFFE_use_cases.pdf +0 -0
FairEval.py
CHANGED
@@ -204,38 +204,58 @@ class FairEval(evaluate.Metric):
|
|
204 |
assert mode in ['traditional', 'fair', 'weighted'], 'mode must be \'traditional\', \'fair\' or \'weighted\''
|
205 |
assert error_format in ['count', 'error_ratio', 'entity_ratio'], 'error_format must be \'count\', \'error_ratio\' or \'entity_ratio\''
|
206 |
|
207 |
-
# append entity-level errors
|
208 |
-
for k, v in results['per_label']['fair'].items():
|
209 |
-
output[k] = {'TP': v['TP'] / fair_divider if error_format == 'entity_ratio' else v['TP'],
|
210 |
-
'FP': v['FP'] / fair_divider, 'FN': v['FN'] / fair_divider,
|
211 |
-
'LE': v['LE'] / fair_divider, 'BE': v['BE'] / fair_divider, 'LBE': v['LBE'] / fair_divider,}
|
212 |
-
|
213 |
-
# append entity-level scores (depending on mode)
|
214 |
if mode == 'traditional':
|
215 |
for k, v in results['per_label'][mode].items():
|
216 |
-
output[k]
|
|
|
|
|
|
|
|
|
|
|
217 |
elif mode == 'fair' or mode == 'weighted':
|
218 |
for k, v in results['per_label'][mode].items():
|
219 |
-
output[k]
|
|
|
|
|
|
|
|
|
|
|
|
|
220 |
|
221 |
-
|
|
|
|
|
|
|
|
|
|
|
222 |
output['overall_precision'] = results['overall'][mode]['Prec']
|
223 |
output['overall_recall'] = results['overall'][mode]['Rec']
|
224 |
output['overall_f1'] = results['overall'][mode]['F1']
|
225 |
|
226 |
-
# append overall error counts (
|
227 |
-
|
228 |
-
|
229 |
-
|
230 |
-
|
231 |
-
|
232 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
233 |
|
234 |
return output
|
235 |
|
236 |
|
237 |
def seq_to_fair(seq_sentences):
|
238 |
-
"Transforms input
|
239 |
out = []
|
240 |
for seq_sentence in seq_sentences:
|
241 |
sentence = []
|
|
|
204 |
assert mode in ['traditional', 'fair', 'weighted'], 'mode must be \'traditional\', \'fair\' or \'weighted\''
|
205 |
assert error_format in ['count', 'error_ratio', 'entity_ratio'], 'error_format must be \'count\', \'error_ratio\' or \'entity_ratio\''
|
206 |
|
207 |
+
# append entity-level errors and scores
|
|
|
|
|
|
|
|
|
|
|
|
|
208 |
if mode == 'traditional':
|
209 |
for k, v in results['per_label'][mode].items():
|
210 |
+
output[k] = {# traditional scores
|
211 |
+
'precision': v['Prec'], 'recall': v['Rec'], 'f1': v['F1'],
|
212 |
+
|
213 |
+
# traditional errors
|
214 |
+
'TP': v['TP'] / trad_divider if error_format == 'entity_ratio' else v['TP'],
|
215 |
+
'FP': v['FP'] / trad_divider, 'FN': v['FN'] / trad_divider}
|
216 |
elif mode == 'fair' or mode == 'weighted':
|
217 |
for k, v in results['per_label'][mode].items():
|
218 |
+
output[k] = {# fair/weighted scores
|
219 |
+
'precision': v['Prec'], 'recall': v['Rec'], 'f1': v['F1'],
|
220 |
+
|
221 |
+
# traditional scores
|
222 |
+
'trad_prec': results['per_label']['traditional'][k]['Prec'],
|
223 |
+
'trad_rec': results['per_label']['traditional'][k]['Rec'],
|
224 |
+
'trad_f1': results['per_label']['traditional'][k]['F1'],
|
225 |
|
226 |
+
# fair/weighted errors
|
227 |
+
'TP': v['TP'] / fair_divider if error_format == 'entity_ratio' else v['TP'],
|
228 |
+
'FP': v['FP'] / fair_divider, 'FN': v['FN'] / fair_divider,
|
229 |
+
'LE': v['LE'] / fair_divider, 'BE': v['BE'] / fair_divider, 'LBE': v['LBE'] / fair_divider}
|
230 |
+
|
231 |
+
# append overall scores
|
232 |
output['overall_precision'] = results['overall'][mode]['Prec']
|
233 |
output['overall_recall'] = results['overall'][mode]['Rec']
|
234 |
output['overall_f1'] = results['overall'][mode]['F1']
|
235 |
|
236 |
+
# append overall error counts (and trad scores if mode is fair)
|
237 |
+
if mode == 'traditional':
|
238 |
+
output['TP'] = results['overall'][mode]['TP'] / trad_divider if error_format == 'entity_ratio' else \
|
239 |
+
results['overall'][mode]['TP']
|
240 |
+
output['FP'] = results['overall'][mode]['FP'] / trad_divider
|
241 |
+
output['FN'] = results['overall'][mode]['FN'] / trad_divider
|
242 |
+
elif mode == 'fair' or 'weighted':
|
243 |
+
output['overall_trad_prec'] = results['overall']['traditional']['Prec']
|
244 |
+
output['overall_trad_rec'] = results['overall']['traditional']['Rec']
|
245 |
+
output['overall_trad_f1'] = results['overall']['traditional']['F1']
|
246 |
+
output['TP'] = results['overall'][mode]['TP'] / fair_divider if error_format == 'entity_ratio' else \
|
247 |
+
results['overall'][mode]['TP']
|
248 |
+
output['FP'] = results['overall'][mode]['FP'] / fair_divider
|
249 |
+
output['FN'] = results['overall'][mode]['FN'] / fair_divider
|
250 |
+
output['LE'] = results['overall'][mode]['LE'] / fair_divider
|
251 |
+
output['BE'] = results['overall'][mode]['BE'] / fair_divider
|
252 |
+
output['LBE'] = results['overall'][mode]['LBE'] / fair_divider
|
253 |
|
254 |
return output
|
255 |
|
256 |
|
257 |
def seq_to_fair(seq_sentences):
|
258 |
+
"Transforms input annotated sentences from seqeval span format to FairEval span format"
|
259 |
out = []
|
260 |
for seq_sentence in seq_sentences:
|
261 |
sentence = []
|
HFFE_use_cases.pdf
DELETED
Binary file (86.4 kB)
|
|