Victoria Oberascher committed
Commit 9b22cca · 1 Parent(s): aab971e

implement confidence curve feature
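In short, the feature sweeps the detection confidence threshold and reports precision, recall and F1 at each value, so that curves for several models can be compared in one plot. As a rough, self-contained sketch of what such a curve measures (illustrative only; the function and argument names below are hypothetical and this is not the code added by this commit, which reads these values from the COCO evaluation state):

import numpy as np

def confidence_curve(scores, is_tp, n_gt, thresholds=np.linspace(0.0, 1.0, 101)):
    """Precision/recall/F1 as a function of the confidence threshold (illustrative only)."""
    scores = np.asarray(scores, dtype=float)   # per-detection confidence scores
    is_tp = np.asarray(is_tp, dtype=bool)      # True where a detection matched a ground-truth box
    precision, recall, f1 = [], [], []
    for t in thresholds:
        keep = scores >= t                     # detections surviving the threshold
        tp = np.count_nonzero(is_tp & keep)
        fp = np.count_nonzero(~is_tp & keep)
        p = tp / (tp + fp) if (tp + fp) else 1.0
        r = tp / n_gt if n_gt else 0.0
        f = 2 * p * r / (p + r) if (p + r) else 0.0
        precision.append(p)
        recall.append(r)
        f1.append(f)
    return thresholds, np.array(precision), np.array(recall), np.array(f1)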

Files changed (1)
  1. det-metrics.py +123 -48
det-metrics.py CHANGED
@@ -20,6 +20,7 @@ import evaluate
 import numpy as np
 from deprecated import deprecated
 from seametrics.detection import PrecisionRecallF1Support
+from seametrics.detection.utils import payload_to_det_metric
 from seametrics.payload import Payload
 
 _CITATION = """\
@@ -91,7 +92,7 @@ Examples:
     >>> from seametrics.payload.processor import PayloadProcessor
     >>> payload = PayloadProcessor(...).payload
     >>> module = evaluate.load("SEA-AI/det-metrics", ...)
-    >>> module.add_payload(payload)
+    >>> module._add_payload(payload)
     >>> result = module.compute()
     >>> print(result)
     {'all': {
@@ -122,20 +123,36 @@ class DetectionMetric(evaluate.Metric):
         class_agnostic: bool = True,
         bbox_format: str = "xywh",
         iou_type: Literal["bbox", "segm"] = "bbox",
-        **kwargs
+        payload: Payload = None,
+        **kwargs,
     ):
         super().__init__(**kwargs)
+
+        # save parameters for later
+        self.payload = payload
+        self.model_names = payload.models if payload else ["custom"]
+        self.iou_thresholds = (
+            iou_threshold if isinstance(iou_threshold, list) else [iou_threshold]
+        )
+        self.area_ranges = [v for _, v in area_ranges_tuples]
+        self.area_ranges_labels = [k for k, _ in area_ranges_tuples]
+        self.class_agnostic = class_agnostic
+        self.iou_type = iou_type
+        self.box_format = bbox_format
+
+        # initialize coco_metrics
         self.coco_metric = PrecisionRecallF1Support(
-            iou_thresholds=(
-                iou_threshold if isinstance(iou_threshold, list) else [iou_threshold]
-            ),
-            area_ranges=[v for _, v in area_ranges_tuples],
-            area_ranges_labels=[k for k, _ in area_ranges_tuples],
-            class_agnostic=class_agnostic,
-            iou_type=iou_type,
-            box_format=bbox_format,
+            iou_thresholds=self.iou_thresholds,
+            area_ranges=self.area_ranges,
+            area_ranges_labels=self.area_ranges_labels,
+            class_agnostic=self.class_agnostic,
+            iou_type=self.iou_type,
+            box_format=self.box_format,
         )
 
+        # initialize evaluation metric
+        self._init_evaluation_metric()
+
     def _info(self):
         return evaluate.MetricInfo(
             # This is the description that will appear on the modules page.
@@ -185,29 +202,63 @@ class DetectionMetric(evaluate.Metric):
 
         self.coco_metric.update(prediction, reference)
 
+    def _init_evaluation_metric(self, **kwargs):
+        """
+        Initializes the evaluation metric by generating sample data, preprocessing predictions and references,
+        and then adding the processed data to the metric using the super class method with additional keyword arguments.
+
+        Parameters:
+            **kwargs: Additional keyword arguments for the super class method.
+
+        Returns:
+            None
+        """
+        predictions, references = self._generate_sample_data()
+        predictions = self._preprocess(predictions)
+        references = self._preprocess(references)
+
         # does not impact the metric, but is required for the interface x_x
         super(evaluate.Metric, self).add(
-            prediction=self._postprocess(prediction),
-            references=self._postprocess(reference),
-            **kwargs
+            prediction=self._postprocess(predictions),
+            references=self._postprocess(references),
+            **kwargs,
         )
 
-    @deprecated(reason="Use `module.add_payload` instead")
+    @deprecated(reason="Use `module._add_payload` instead")
     def add_batch(self, payload: Payload, model_name: str = None):
         """Takes as input a payload and adds the batch to the metric"""
-        self.add_payload(payload, model_name)
+        self._add_payload(payload, model_name)
 
     def _compute(self, *, predictions, references, **kwargs):
         """Called within the evaluate.Metric.compute() method"""
-        return self.coco_metric.compute()
 
-    def add_payload(self, payload: Payload, model_name: str = None):
+        results = {}
+        for model_name in self.model_names:
+            print(f"\n##### {model_name} #####")
+            # add payload if available (otherwise predictions and references must be added with add function)
+            if self.payload:
+                self._add_payload(self.payload, model_name)
+
+            results[model_name] = self.coco_metric.compute()
+
+            # reset coco_metrics for next model
+            self.coco_metric = PrecisionRecallF1Support(
+                iou_thresholds=self.iou_thresholds,
+                area_ranges=self.area_ranges,
+                area_ranges_labels=self.area_ranges_labels,
+                class_agnostic=self.class_agnostic,
+                iou_type=self.iou_type,
+                box_format=self.box_format,
+            )
+        return results
+
+    def _add_payload(self, payload: Payload, model_name: str = None):
         """Converts the payload to the format expected by the metric"""
         # import only if needed since fiftyone is not a direct dependency
-        from seametrics.detection.utils import payload_to_det_metric
 
         predictions, references = payload_to_det_metric(payload, model_name)
         self.add(prediction=predictions, reference=references)
+
         return self
 
     def _preprocess(self, list_of_dicts):
@@ -235,55 +286,79 @@ class DetectionMetric(evaluate.Metric):
             elif isinstance(v, list):
                 d[k] = np.array(v)
         return d
-
-    def compute_for_multiple_models(self, payload):
-        results = {}
-        for model_name in payload.models:
-            self.add_payload(payload, model_name)
-            results[model_name] = self._compute()
 
-        return results
 
-    def generate_confidence_curves(self, results, models, confidence_config = {"T":0,
-                                                                                "R":0,
-                                                                                "K":0,
-                                                                                "A":0,
-                                                                                "M":0}):
-
+    def generate_confidence_curves(
+        self, results, confidence_config={"T": 0, "R": 0, "K": 0, "A": 0, "M": 0}
+    ):
+        """
+        Generate confidence curves based on results and confidence configuration.
+
+        Parameters:
+            results (dict): Results of the evaluation for different models.
+            confidence_config (dict): Configuration for confidence values.
+                Defaults to {"T": 0, "R": 0, "K": 0, "A": 0, "M": 0}.
+                T: [1e-10] iou threshold
+                R: recall threshold (not used)
+                K: class index (class-agnostic mAP, so only 0)
+                A: 0=all, 1=small, 2=medium, 3=large, ... (depending on area ranges)
+                M: [100] maxDets default in precision_recall_f1_support
+
+        Returns:
+            fig (plotly.graph_objects.Figure): The plotly figure showing the confidence curves.
+        """
         import plotly.graph_objects as go
         from seametrics.detection.utils import get_confidence_metric_vals
 
         # Create traces
         fig = go.Figure()
-        metrics = ['precision', 'recall', 'f1']
-        for model in models:
+        metrics = ["precision", "recall", "f1"]
+        for model_name in self.model_names:
+            print(f"##### {model_name} #####")
             plot_data = get_confidence_metric_vals(
-                cocoeval=results[model['name']]['eval'],
-                T=confidence_config['T'],
-                R=confidence_config['R'],
-                K=confidence_config['K'],
-                A=confidence_config['A'],
-                M=confidence_config['M']
+                cocoeval=results[model_name]["eval"],
+                T=confidence_config["T"],
+                R=confidence_config["R"],
+                K=confidence_config["K"],
+                A=confidence_config["A"],
+                M=confidence_config["M"],
            )
 
            for metric in metrics:
                fig.add_trace(
                    go.Scatter(
-                        x=plot_data['conf'],
+                        x=plot_data["conf"],
                        y=plot_data[metric],
-                        mode='lines',
-                        name=f"{model['name'].split('_')[0]} {metric}",
-                        line=dict(dash=None if metric == 'f1' else 'dash'),
+                        mode="lines",
+                        name=f"{model_name} {metric}",
+                        line=dict(dash=None if metric == "f1" else "dash"),
                    )
                )

        fig.update_layout(
            title="Metric vs Confidence",
-            hovermode='x unified',
+            hovermode="x unified",
            xaxis_title="Confidence",
-            yaxis_title="Metric value")
-        fig.show()
-
+            yaxis_title="Metric value",
+        )
        return fig
-
+
+    def _generate_sample_data(self):
+        """
+        Generates dummy sample data for predictions and references used for initialization.
+
+        Returns:
+            Tuple[List[Dict[str, List[Union[float, int]]]], List[Dict[str, List[Union[float, int]]]]]:
+                - predictions (List[Dict[str, List[Union[float, int]]]]): A list of dictionaries representing the predictions. Each dictionary contains the following keys:
+                    - boxes (List[List[float]]): A list of bounding boxes in the format [x, y, w, h].
+                    - labels (List[int]): A list of labels.
+                    - scores (List[float]): A list of scores.
+                - references (List[Dict[str, List[Union[float, int]]]]): A list of dictionaries representing the references. Each dictionary contains the following keys:
+                    - boxes (List[List[float]]): A list of bounding boxes in the format [x, y, w, h].
+                    - labels (List[int]): A list of labels.
+                    - area (List[float]): A list of areas.
+        """
+        predictions = [
+            {"boxes": [[1.0, 2.0, 3.0, 4.0]], "labels": [0], "scores": [1.0]}
+        ]
+        references = [{"boxes": [[1.0, 2.0, 3.0, 4.0]], "labels": [0], "area": [1.0]}]
+
+        return predictions, references
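Taken together, the diff changes the workflow: the payload is supplied at construction time, compute() iterates over every model in the payload and returns a dict keyed by model name, and generate_confidence_curves() turns those results into a plotly figure. A possible end-to-end sketch, following the module's docstring example and assuming evaluate.load forwards extra keyword arguments (here payload) to the metric constructor; the PayloadProcessor arguments are left elided as in the docstring:

import evaluate
from seametrics.payload.processor import PayloadProcessor

# build a payload that may contain predictions from several models
payload = PayloadProcessor(...).payload  # processor arguments omitted, as in the docstring

# the payload is handed to the metric at construction time
module = evaluate.load("SEA-AI/det-metrics", payload=payload)

# compute() loops over payload.models and returns one result dict per model name
results = module.compute()

# turn the per-model results into precision/recall/F1-vs-confidence traces;
# T/R/K/A/M index into the underlying COCO evaluation arrays
# (defaults: first IoU threshold, class 0, "all" area range, maxDets=100)
fig = module.generate_confidence_curves(
    results,
    confidence_config={"T": 0, "R": 0, "K": 0, "A": 0, "M": 0},
)
fig.show()

Note that generate_confidence_curves no longer calls fig.show() itself, so the caller decides how the figure is rendered.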