illorca committed on
Commit
fa93db6
1 Parent(s): 53ac266

Avoid dict overrides for entity-level

Browse files
Files changed (2) hide show
  1. FairEval.py +38 -18
  2. HFFE_use_cases.pdf +0 -0
FairEval.py CHANGED
@@ -204,38 +204,58 @@ class FairEval(evaluate.Metric):
204
  assert mode in ['traditional', 'fair', 'weighted'], 'mode must be \'traditional\', \'fair\' or \'weighted\''
205
  assert error_format in ['count', 'error_ratio', 'entity_ratio'], 'error_format must be \'count\', \'error_ratio\' or \'entity_ratio\''
206
 
207
- # append entity-level errors (always fair)
208
- for k, v in results['per_label']['fair'].items():
209
- output[k] = {'TP': v['TP'] / fair_divider if error_format == 'entity_ratio' else v['TP'],
210
- 'FP': v['FP'] / fair_divider, 'FN': v['FN'] / fair_divider,
211
- 'LE': v['LE'] / fair_divider, 'BE': v['BE'] / fair_divider, 'LBE': v['LBE'] / fair_divider,}
212
-
213
- # append entity-level scores (depending on mode)
214
  if mode == 'traditional':
215
  for k, v in results['per_label'][mode].items():
216
- output[k].update({'precision': v['Prec'], 'recall': v['Rec'], 'f1': v['F1'],})
 
 
 
 
 
217
  elif mode == 'fair' or mode == 'weighted':
218
  for k, v in results['per_label'][mode].items():
219
- output[k].update({'precision': v['Prec'], 'recall': v['Rec'], 'f1': v['F1'],})
 
 
 
 
 
 
220
 
221
- # append overall scores (depending on mode)
 
 
 
 
 
222
  output['overall_precision'] = results['overall'][mode]['Prec']
223
  output['overall_recall'] = results['overall'][mode]['Rec']
224
  output['overall_f1'] = results['overall'][mode]['F1']
225
 
226
- # append overall error counts (always fair)
227
- output['TP'] = results['overall']['fair']['TP'] / fair_divider if error_format == 'entity_ratio' else results['overall'][mode]['TP']
228
- output['FP'] = results['overall']['fair']['FP'] / fair_divider
229
- output['FN'] = results['overall']['fair']['FN'] / fair_divider
230
- output['LE'] = results['overall']['fair']['LE'] / fair_divider
231
- output['BE'] = results['overall']['fair']['BE'] / fair_divider
232
- output['LBE'] = results['overall']['fair']['LBE'] / fair_divider
 
 
 
 
 
 
 
 
 
 
233
 
234
  return output
235
 
236
 
237
  def seq_to_fair(seq_sentences):
238
- "Transforms input anotated sentences from seqeval span format to FairEval span format"
239
  out = []
240
  for seq_sentence in seq_sentences:
241
  sentence = []
 
204
  assert mode in ['traditional', 'fair', 'weighted'], 'mode must be \'traditional\', \'fair\' or \'weighted\''
205
  assert error_format in ['count', 'error_ratio', 'entity_ratio'], 'error_format must be \'count\', \'error_ratio\' or \'entity_ratio\''
206
 
207
+ # append entity-level errors and scores
 
 
 
 
 
 
208
  if mode == 'traditional':
209
  for k, v in results['per_label'][mode].items():
210
+ output[k] = {# traditional scores
211
+ 'precision': v['Prec'], 'recall': v['Rec'], 'f1': v['F1'],
212
+
213
+ # traditional errors
214
+ 'TP': v['TP'] / trad_divider if error_format == 'entity_ratio' else v['TP'],
215
+ 'FP': v['FP'] / trad_divider, 'FN': v['FN'] / trad_divider}
216
  elif mode == 'fair' or mode == 'weighted':
217
  for k, v in results['per_label'][mode].items():
218
+ output[k] = {# fair/weighted scores
219
+ 'precision': v['Prec'], 'recall': v['Rec'], 'f1': v['F1'],
220
+
221
+ # traditional scores
222
+ 'trad_prec': results['per_label']['traditional'][k]['Prec'],
223
+ 'trad_rec': results['per_label']['traditional'][k]['Rec'],
224
+ 'trad_f1': results['per_label']['traditional'][k]['F1'],
225
 
226
+ # fair/weighted errors
227
+ 'TP': v['TP'] / fair_divider if error_format == 'entity_ratio' else v['TP'],
228
+ 'FP': v['FP'] / fair_divider, 'FN': v['FN'] / fair_divider,
229
+ 'LE': v['LE'] / fair_divider, 'BE': v['BE'] / fair_divider, 'LBE': v['LBE'] / fair_divider}
230
+
231
+ # append overall scores
232
  output['overall_precision'] = results['overall'][mode]['Prec']
233
  output['overall_recall'] = results['overall'][mode]['Rec']
234
  output['overall_f1'] = results['overall'][mode]['F1']
235
 
236
+ # append overall error counts (and trad scores if mode is fair)
237
+ if mode == 'traditional':
238
+ output['TP'] = results['overall'][mode]['TP'] / trad_divider if error_format == 'entity_ratio' else \
239
+ results['overall'][mode]['TP']
240
+ output['FP'] = results['overall'][mode]['FP'] / trad_divider
241
+ output['FN'] = results['overall'][mode]['FN'] / trad_divider
242
+ elif mode == 'fair' or 'weighted':
243
+ output['overall_trad_prec'] = results['overall']['traditional']['Prec']
244
+ output['overall_trad_rec'] = results['overall']['traditional']['Rec']
245
+ output['overall_trad_f1'] = results['overall']['traditional']['F1']
246
+ output['TP'] = results['overall'][mode]['TP'] / fair_divider if error_format == 'entity_ratio' else \
247
+ results['overall'][mode]['TP']
248
+ output['FP'] = results['overall'][mode]['FP'] / fair_divider
249
+ output['FN'] = results['overall'][mode]['FN'] / fair_divider
250
+ output['LE'] = results['overall'][mode]['LE'] / fair_divider
251
+ output['BE'] = results['overall'][mode]['BE'] / fair_divider
252
+ output['LBE'] = results['overall'][mode]['LBE'] / fair_divider
253
 
254
  return output
255
 
256
 
257
  def seq_to_fair(seq_sentences):
258
+ "Transforms input annotated sentences from seqeval span format to FairEval span format"
259
  out = []
260
  for seq_sentence in seq_sentences:
261
  sentence = []
HFFE_use_cases.pdf DELETED
Binary file (86.4 kB)