illorca committed
Commit 066589e · 1 Parent(s): d8424e9

First try entity_ratio option

Files changed (1)
  1. FairEval.py +41 -45
FairEval.py CHANGED
@@ -147,21 +147,24 @@ class FairEvaluation(evaluate.Metric):
         true_spans = seq_to_fair(true_spans)
         pred_spans = seq_to_fair(pred_spans)
 
-        # (3) COUNT ERRORS AND CALCULATE SCORES
+        # (3) COUNT ERRORS AND CALCULATE SCORES (counting total ground truth entities too)
         total_errors = compare_spans([], [])
+        total_ref_entities = 0
         for i in range(len(true_spans)):
+            total_ref_entities += len(true_spans[i])
             sentence_errors = compare_spans(true_spans[i], pred_spans[i])
             total_errors = add_dict(total_errors, sentence_errors)
 
         if weights is None and mode == 'weighted':
-            print("The chosen mode is \'weighted\', but no weights are given. Setting weights to:\n")
             weights = {"TP": {"TP": 1},
                        "FP": {"FP": 1},
                        "FN": {"FN": 1},
                        "LE": {"TP": 0, "FP": 0.5, "FN": 0.5},
                        "BE": {"TP": 0.5, "FP": 0.25, "FN": 0.25},
                        "LBE": {"TP": 0, "FP": 0.5, "FN": 0.5}}
-            print(weights)
+            print("The chosen mode is \'weighted\', but no weights are given. Setting weights to:\n")
+            for k in weights:
+                print(k, ":", weights[k])
 
         config = {"labels": "all", "eval_method": [mode], "weights": weights,}
         results = calculate_results(total_errors, config)
@@ -170,34 +173,36 @@ class FairEvaluation(evaluate.Metric):
         # (4) SELECT OUTPUT MODE AND REFORMAT AS SEQEVAL-HUGGINGFACE OUTPUT
         # initialize empty dictionary and count errors
         output = {}
-        total_trad_errors = results['overall']['traditional']['FP'] + results['overall']['traditional']['FN']
-        total_fair_errors = results['overall']['fair']['FP'] + results['overall']['fair']['FN'] + \
-                            results['overall']['fair']['LE'] + results['overall']['fair']['BE'] + \
-                            results['overall']['fair']['LBE']
+        # control the divider for the error_format (count, proportion over total errors or over total entities)
+        if error_format == 'count':
+            trad_divider = 1
+            fair_divider = 1
+        elif error_format == 'entity_ratio':
+            trad_divider = total_ref_entities
+            fair_divider = total_ref_entities
+        elif error_format == 'error_ratio':
+            trad_divider = results['overall']['traditional']['FP'] + results['overall']['traditional']['FN']
+            fair_divider = results['overall']['fair']['FP'] + results['overall']['fair']['FN'] + \
+                           results['overall']['fair']['LE'] + results['overall']['fair']['BE'] + \
+                           results['overall']['fair']['LBE']
 
         # assert valid options
         assert mode in ['traditional', 'fair', 'weighted'], 'mode must be \'traditional\', \'fair\' or \'weighted\''
-        assert error_format in ['count', 'proportion'], 'error_format must be \'count\' or \'proportion\''
+        assert error_format in ['count', 'error_ratio', 'entity_ratio'], 'error_format must be \'count\', \'error_ratio\' or \'entity_ratio\''
 
         # append entity-level errors and scores
         if mode == 'traditional':
             for k, v in results['per_label'][mode].items():
-                if error_format == 'count':
-                    output[k] = {'precision': v['Prec'], 'recall': v['Rec'], 'f1': v['F1'], 'TP': v['TP'],
-                                 'FP': v['FP'], 'FN': v['FN']}
-                elif error_format == 'proportion':
-                    output[k] = {'precision': v['Prec'], 'recall': v['Rec'], 'f1': v['F1'], 'TP': v['TP'],
-                                 'FP': v['FP'] / total_trad_errors, 'FN': v['FN'] / total_trad_errors}
+                output[k] = {'precision': v['Prec'], 'recall': v['Rec'], 'f1': v['F1'],
+                             'TP': v['TP'] / trad_divider if error_format == 'entity_ratio' else v['TP'],
+                             'FP': v['FP'] / trad_divider, 'FN': v['FN'] / trad_divider}
         elif mode == 'fair' or mode == 'weighted':
             for k, v in results['per_label'][mode].items():
-                if error_format == 'count':
-                    output[k] = {'precision': v['Prec'], 'recall': v['Rec'], 'f1': v['F1'], 'TP': v['TP'],
-                                 'FP': v['FP'], 'FN': v['FN'], 'LE': v['LE'], 'BE': v['BE'], 'LBE': v['LBE']}
-                elif error_format == 'proportion':
-                    output[k] = {'precision': v['Prec'], 'recall': v['Rec'], 'f1': v['F1'], 'TP': v['TP'],
-                                 'FP': v['FP'] / total_fair_errors, 'FN': v['FN'] / total_fair_errors,
-                                 'LE': v['LE'] / total_fair_errors, 'BE': v['BE'] / total_fair_errors,
-                                 'LBE': v['LBE'] / total_fair_errors}
+                output[k] = {'precision': v['Prec'], 'recall': v['Rec'], 'f1': v['F1'],
+                             'TP': v['TP'] / fair_divider if error_format == 'entity_ratio' else v['TP'],
+                             'FP': v['FP'] / fair_divider, 'FN': v['FN'] / fair_divider,
+                             'LE': v['LE'] / fair_divider, 'BE': v['BE'] / fair_divider, 'LBE': v['LBE'] / fair_divider}
 
         # append overall scores
         output['overall_precision'] = results['overall'][mode]['Prec']
@@ -206,25 +211,16 @@ class FairEvaluation(evaluate.Metric):
 
         # append overall error counts
         if mode == 'traditional':
-            output['TP'] = results['overall'][mode]['TP']
-            output['FP'] = results['overall'][mode]['FP']
-            output['FN'] = results['overall'][mode]['FN']
-            if error_format == 'proportion':
-                output['FP'] = output['FP'] / total_trad_errors
-                output['FN'] = output['FN'] / total_trad_errors
+            output['TP'] = results['overall'][mode]['TP'] / trad_divider if error_format == 'entity_ratio' else results['overall'][mode]['TP']
+            output['FP'] = results['overall'][mode]['FP'] / trad_divider
+            output['FN'] = results['overall'][mode]['FN'] / trad_divider
         elif mode == 'fair' or mode == 'weighted':
-            output['TP'] = results['overall'][mode]['TP']
-            output['FP'] = results['overall'][mode]['FP']
-            output['FN'] = results['overall'][mode]['FN']
-            output['LE'] = results['overall'][mode]['LE']
-            output['BE'] = results['overall'][mode]['BE']
-            output['LBE'] = results['overall'][mode]['LBE']
-            if error_format == 'proportion':
-                output['FP'] = output['FP'] / total_fair_errors
-                output['FN'] = output['FN'] / total_fair_errors
-                output['LE'] = output['LE'] / total_fair_errors
-                output['BE'] = output['BE'] / total_fair_errors
-                output['LBE'] = output['LBE'] / total_fair_errors
+            output['TP'] = results['overall'][mode]['TP'] / fair_divider if error_format == 'entity_ratio' else results['overall'][mode]['TP']
+            output['FP'] = results['overall'][mode]['FP'] / fair_divider
+            output['FN'] = results['overall'][mode]['FN'] / fair_divider
+            output['LE'] = results['overall'][mode]['LE'] / fair_divider
+            output['BE'] = results['overall'][mode]['BE'] / fair_divider
+            output['LBE'] = results['overall'][mode]['LBE'] / fair_divider
 
         return output
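For context, a minimal usage sketch of the new error_format option. This is not part of the commit: the Hub repo id ("hpi-dhc/FairEval") and the label scheme are illustrative assumptions, and it assumes compute() forwards the mode and error_format keyword arguments to the _compute() edited above.

# Hedged usage sketch; repo id and labels are assumptions, not confirmed here.
import evaluate

fair_eval = evaluate.load("hpi-dhc/FairEval")  # assumed Hub repo id

# One sentence with two gold entities: PER (tokens 1-2) and LOC (token 4).
references = [["O", "B-PER", "I-PER", "O", "B-LOC"]]
# The predictions shorten PER (a boundary error, BE) and mislabel LOC as ORG
# (a label error, LE), so fair counting sees two errors and no true positives.
predictions = [["O", "B-PER", "O", "O", "B-ORG"]]

for fmt in ("count", "error_ratio", "entity_ratio"):
    res = fair_eval.compute(predictions=predictions, references=references,
                            mode="fair", error_format=fmt)
    # 'count' keeps raw counts; 'error_ratio' divides the error counts by the
    # total number of fair errors; 'entity_ratio' divides every count, TP
    # included, by total_ref_entities (2 here).
    print(fmt, {k: res[k] for k in ("TP", "FP", "FN", "LE", "BE", "LBE")})

With two reference entities and two fair errors, error_ratio and entity_ratio happen to coincide in this toy example; on a real corpus they diverge, which is the point of the new option: error rates normalized by gold entity count stay comparable across corpora of different sizes.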