Victoria Oberascher commited on
Commit
003e48f
·
1 Parent(s): 35fe85d

add function to generate confidence curves

Browse files
Files changed (2) hide show
  1. det-metrics.py +56 -5
  2. requirements.txt +2 -1
det-metrics.py CHANGED
@@ -13,13 +13,12 @@
13
  # limitations under the License.
14
  """TODO: Add a description here."""
15
 
16
- from typing import List, Tuple, Literal
17
- from deprecated import deprecated
18
 
19
- import evaluate
20
  import datasets
 
21
  import numpy as np
22
-
23
  from seametrics.detection import PrecisionRecallF1Support
24
  from seametrics.payload import Payload
25
 
@@ -200,7 +199,7 @@ class DetectionMetric(evaluate.Metric):
200
 
201
  def _compute(self, *, predictions, references, **kwargs):
202
  """Called within the evaluate.Metric.compute() method"""
203
- return self.coco_metric.compute()["metrics"]
204
 
205
  def add_payload(self, payload: Payload, model_name: str = None):
206
  """Converts the payload to the format expected by the metric"""
@@ -236,3 +235,55 @@ class DetectionMetric(evaluate.Metric):
236
  elif isinstance(v, list):
237
  d[k] = np.array(v)
238
  return d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  # limitations under the License.
14
  """TODO: Add a description here."""
15
 
16
+ from typing import List, Literal, Tuple
 
17
 
 
18
  import datasets
19
+ import evaluate
20
  import numpy as np
21
+ from deprecated import deprecated
22
  from seametrics.detection import PrecisionRecallF1Support
23
  from seametrics.payload import Payload
24
 
 
199
 
200
  def _compute(self, *, predictions, references, **kwargs):
201
  """Called within the evaluate.Metric.compute() method"""
202
+ return self.coco_metric.compute()
203
 
204
  def add_payload(self, payload: Payload, model_name: str = None):
205
  """Converts the payload to the format expected by the metric"""
 
235
  elif isinstance(v, list):
236
  d[k] = np.array(v)
237
  return d
238
+
239
+ def compute_for_multiple_models(self, payload):
240
+ results = {}
241
+ for model_name in payload.models:
242
+ self.add_payload(payload, model_name)
243
+ results[model_name] = self._compute()
244
+
245
+ return results
246
+
247
+
248
+ def generate_confidence_curves(self, results, models, confidence_config = {"T":0,
249
+ "R":0,
250
+ "K":0,
251
+ "A":0,
252
+ "M":0}):
253
+
254
+ import plotly.graph_objects as go
255
+ from seametrics.detection.utils import get_confidence_metric_vals
256
+
257
+ # Create traces
258
+ fig = go.Figure()
259
+ metrics = ['precision', 'recall', 'f1']
260
+ for model in models:
261
+ plot_data = get_confidence_metric_vals(
262
+ cocoeval=results[model['name']]['eval'],
263
+ T=confidence_config['T'],
264
+ R=confidence_config['R'],
265
+ K=confidence_config['K'],
266
+ A=confidence_config['A'],
267
+ M=confidence_config['M']
268
+ )
269
+
270
+ for metric in metrics:
271
+ fig.add_trace(
272
+ go.Scatter(
273
+ x=plot_data['conf'],
274
+ y=plot_data[metric],
275
+ mode='lines',
276
+ name=f"{model['name'].split('_')[0]} {metric}",
277
+ line=dict(dash=None if metric == 'f1' else 'dash'),
278
+ )
279
+ )
280
+
281
+ fig.update_layout(
282
+ title="Metric vs Confidence",
283
+ hovermode='x unified',
284
+ xaxis_title="Confidence",
285
+ yaxis_title="Metric value")
286
+ fig.show()
287
+
288
+ return fig
289
+
requirements.txt CHANGED
@@ -1,3 +1,4 @@
1
  git+https://github.com/huggingface/evaluate@main
2
  git+https://github.com/SEA-AI/seametrics@develop
3
- fiftyone
 
 
1
  git+https://github.com/huggingface/evaluate@main
2
  git+https://github.com/SEA-AI/seametrics@develop
3
+ fiftyone
4
+ plotly