nbansal commited on
Commit
9db3d74
·
1 Parent(s): fdd202a

Refactor SemF1 aggregation logic and fix typo in comment

Browse files
Files changed (2) hide show
  1. semf1.py +4 -7
  2. tests.py +1 -0
semf1.py CHANGED
@@ -422,8 +422,8 @@ class SemF1(evaluate.Metric):
422
  recall_scores = [np.clip(r_scores, 0.0, 1.0).item() for (r_scores, _) in recall_scores]
423
 
424
  results.append(Scores(precision, recall_scores))
425
-
426
- # runn aggregation procedure
427
  if aggregate:
428
  mean_prec = np.mean(
429
  [score.precision for score in results]
@@ -432,12 +432,9 @@ class SemF1(evaluate.Metric):
432
  [np.array(score.recall) for score in results]
433
  ))
434
  aggregated_score = Scores(
435
- float(mean_prec),
436
  [float(mean_recall)]
437
  )
438
- aggregated_score.f1 = float(np.mean(
439
- [score.f1 for score in results]
440
- ))
441
  results = aggregated_score
442
 
443
- return results
 
422
  recall_scores = [np.clip(r_scores, 0.0, 1.0).item() for (r_scores, _) in recall_scores]
423
 
424
  results.append(Scores(precision, recall_scores))
425
+
426
+ # run aggregation procedure
427
  if aggregate:
428
  mean_prec = np.mean(
429
  [score.precision for score in results]
 
432
  [np.array(score.recall) for score in results]
433
  ))
434
  aggregated_score = Scores(
435
+ float(mean_prec),
436
  [float(mean_recall)]
437
  )
 
 
 
438
  results = aggregated_score
439
 
440
+ return results
tests.py CHANGED
@@ -708,5 +708,6 @@ class TestValidateInputFormat(unittest.TestCase):
708
  def run_tests():
709
  unittest.main(verbosity=2)
710
 
 
711
  if __name__ == '__main__':
712
  run_tests()
 
708
  def run_tests():
709
  unittest.main(verbosity=2)
710
 
711
+
712
  if __name__ == '__main__':
713
  run_tests()