jarod0411 commited on
Commit
545e8a6
·
1 Parent(s): a28f1ed
Files changed (1) hide show
  1. aucpr.py +11 -5
aucpr.py CHANGED
@@ -15,6 +15,7 @@
15
 
16
  import evaluate
17
  import datasets
 
18
 
19
 
20
  # TODO: Add BibTeX citation
@@ -70,9 +71,13 @@ class AUCPR(evaluate.Metric):
70
  citation=_CITATION,
71
  inputs_description=_KWARGS_DESCRIPTION,
72
  # This defines the format of each prediction and reference
 
 
 
 
73
  features=datasets.Features({
74
- 'predictions': datasets.Value('int64'),
75
- 'references': datasets.Value('int64'),
76
  }),
77
  # Homepage of the module for documentation
78
  homepage="http://module.homepage",
@@ -86,10 +91,11 @@ class AUCPR(evaluate.Metric):
86
  # TODO: Download external resources if needed
87
  pass
88
 
89
- def _compute(self, predictions, references):
90
  """Returns the scores"""
91
  # TODO: Compute the different scores of the module
92
- accuracy = sum(i == j for i, j in zip(predictions, references)) / len(predictions)
 
93
  return {
94
- "accuracy": accuracy,
95
  }
 
15
 
16
  import evaluate
17
  import datasets
18
+ from sklearn.metrics import precision_recall_curve, auc
19
 
20
 
21
  # TODO: Add BibTeX citation
 
71
  citation=_CITATION,
72
  inputs_description=_KWARGS_DESCRIPTION,
73
  # This defines the format of each prediction and reference
74
+ # features=datasets.Features({
75
+ # 'predictions': datasets.Value('int64'),
76
+ # 'references': datasets.Value('int64'),
77
+ # }),
78
  features=datasets.Features({
79
+ "prediction_scores": datasets.Value("float"),
80
+ "references": datasets.Value("int32"),
81
  }),
82
  # Homepage of the module for documentation
83
  homepage="http://module.homepage",
 
91
  # TODO: Download external resources if needed
92
  pass
93
 
94
+ def _compute(self, references, prediction_scores):
95
  """Returns the scores"""
96
  # TODO: Compute the different scores of the module
97
+ precision, recall, _ = precision_recall_curve(references, prediction_scores)
98
+ aucpr = auc(recall, precision)
99
  return {
100
+ "aucpr": aucpr,
101
  }