illorca committed on
Commit
a92bada
1 Parent(s): e898efc

Traditional mode shows fair errors

Browse files
Files changed (1) hide show
  1. FairEval.py +18 -22
FairEval.py CHANGED
@@ -180,7 +180,7 @@ class FairEval(evaluate.Metric):
180
  for k in weights:
181
  print(k, ":", weights[k])
182
 
183
- config = {"labels": "all", "eval_method": [mode], "weights": weights,}
184
  results = calculate_results(total_errors, config)
185
  del results['conf']
186
 
@@ -204,36 +204,32 @@ class FairEval(evaluate.Metric):
204
  assert mode in ['traditional', 'fair', 'weighted'], 'mode must be \'traditional\', \'fair\' or \'weighted\''
205
  assert error_format in ['count', 'error_ratio', 'entity_ratio'], 'error_format must be \'count\', \'error_ratio\' or \'entity_ratio\''
206
 
207
- # append entity-level errors and scores
 
 
 
 
 
 
208
  if mode == 'traditional':
209
  for k, v in results['per_label'][mode].items():
210
- output[k] = {'precision': v['Prec'], 'recall': v['Rec'], 'f1': v['F1'],
211
- 'TP': v['TP'] / trad_divider if error_format == 'entity_ratio' else v['TP'],
212
- 'FP': v['FP'] / trad_divider, 'FN': v['FN'] / trad_divider}
213
  elif mode == 'fair' or mode == 'weighted':
214
  for k, v in results['per_label'][mode].items():
215
- output[k] = {'precision': v['Prec'], 'recall': v['Rec'], 'f1': v['F1'],
216
- 'TP': v['TP'] / fair_divider if error_format == 'entity_ratio' else v['TP'],
217
- 'FP': v['FP'] / fair_divider, 'FN': v['FN'] / fair_divider,
218
- 'LE': v['LE'] / fair_divider, 'BE': v['BE'] / fair_divider, 'LBE': v['LBE'] / fair_divider}
219
 
220
- # append overall scores
221
  output['overall_precision'] = results['overall'][mode]['Prec']
222
  output['overall_recall'] = results['overall'][mode]['Rec']
223
  output['overall_f1'] = results['overall'][mode]['F1']
224
 
225
- # append overall error counts
226
- if mode == 'traditional':
227
- output['TP'] = results['overall'][mode]['TP'] / trad_divider if error_format == 'entity_ratio' else results['overall'][mode]['TP']
228
- output['FP'] = results['overall'][mode]['FP'] / trad_divider
229
- output['FN'] = results['overall'][mode]['FN'] / trad_divider
230
- elif mode == 'fair' or 'weighted':
231
- output['TP'] = results['overall'][mode]['TP'] / fair_divider if error_format == 'entity_ratio' else results['overall'][mode]['TP']
232
- output['FP'] = results['overall'][mode]['FP'] / fair_divider
233
- output['FN'] = results['overall'][mode]['FN'] / fair_divider
234
- output['LE'] = results['overall'][mode]['LE'] / fair_divider
235
- output['BE'] = results['overall'][mode]['BE'] / fair_divider
236
- output['LBE'] = results['overall'][mode]['LBE'] / fair_divider
237
 
238
  return output
239
 
 
180
  for k in weights:
181
  print(k, ":", weights[k])
182
 
183
+ config = {"labels": "all", "eval_method": ['traditional', 'fair', 'weighted'], "weights": weights,}
184
  results = calculate_results(total_errors, config)
185
  del results['conf']
186
 
 
204
  assert mode in ['traditional', 'fair', 'weighted'], 'mode must be \'traditional\', \'fair\' or \'weighted\''
205
  assert error_format in ['count', 'error_ratio', 'entity_ratio'], 'error_format must be \'count\', \'error_ratio\' or \'entity_ratio\''
206
 
207
+ # append entity-level errors (always fair)
208
+ for k, v in results['per_label']['fair'].items():
209
+ output[k] = {'TP': v['TP'] / fair_divider if error_format == 'entity_ratio' else v['TP'],
210
+ 'FP': v['FP'] / fair_divider, 'FN': v['FN'] / fair_divider,
211
+ 'LE': v['LE'] / fair_divider, 'BE': v['BE'] / fair_divider, 'LBE': v['LBE'] / fair_divider,}
212
+
213
+ # append entity-level scores (depending on mode)
214
  if mode == 'traditional':
215
  for k, v in results['per_label'][mode].items():
216
+ output[k] = {'precision': v['Prec'], 'recall': v['Rec'], 'f1': v['F1'],}
 
 
217
  elif mode == 'fair' or mode == 'weighted':
218
  for k, v in results['per_label'][mode].items():
219
+ output[k] = {'precision': v['Prec'], 'recall': v['Rec'], 'f1': v['F1'],}
 
 
 
220
 
221
+ # append overall scores (depending on mode)
222
  output['overall_precision'] = results['overall'][mode]['Prec']
223
  output['overall_recall'] = results['overall'][mode]['Rec']
224
  output['overall_f1'] = results['overall'][mode]['F1']
225
 
226
+ # append overall error counts (always fair)
227
+ output['TP'] = results['overall']['fair']['TP'] / fair_divider if error_format == 'entity_ratio' else results['overall'][mode]['TP']
228
+ output['FP'] = results['overall']['fair']['FP'] / fair_divider
229
+ output['FN'] = results['overall']['fair']['FN'] / fair_divider
230
+ output['LE'] = results['overall']['fair']['LE'] / fair_divider
231
+ output['BE'] = results['overall']['fair']['BE'] / fair_divider
232
+ output['LBE'] = results['overall']['fair']['LBE'] / fair_divider
 
 
 
 
 
233
 
234
  return output
235