franzi2505 commited on
Commit
de1f95e
·
1 Parent(s): 0bb094e

fix bugs and add improvements

Browse files
Files changed (1) hide show
  1. PanopticQuality.py +28 -8
PanopticQuality.py CHANGED
@@ -71,7 +71,6 @@ Examples:
71
  Added data ...
72
  Start computing ...
73
  Finished!
74
- tensor(0.2082, dtype=torch.float64)
75
  """
76
 
77
 
@@ -81,6 +80,8 @@ class PQMetric(evaluate.Metric):
81
  self,
82
  label2id: Dict[str, int] = None,
83
  stuff: List[str] = None,
 
 
84
  **kwargs
85
  ):
86
  super().__init__(**kwargs)
@@ -109,9 +110,13 @@ class PQMetric(evaluate.Metric):
109
 
110
  self.label2id = label2id if label2id is not None else DEFAULT_LABEL2ID
111
  self.stuff = stuff if stuff is not None else DEFAULT_STUFF
 
 
112
  self.pq_metric = PanopticQuality(
113
  things=set([self.label2id[label] for label in self.label2id.keys() if label not in self.stuff]),
114
- stuffs=set([self.label2id[label] for label in self.label2id.keys() if label in self.stuff])
 
 
115
  )
116
 
117
  def _info(self):
@@ -151,9 +156,6 @@ class PQMetric(evaluate.Metric):
151
  # in case the inputs are lists, convert them to numpy arrays
152
 
153
  self.pq_metric.update(prediction, reference)
154
- print("TP:", self.pq_metric.metric.true_positives)
155
- print("FP:", self.pq_metric.metric.false_positives)
156
- print("FN:", self.pq_metric.metric.false_negatives)
157
 
158
  # does not impact the metric, but is required for the interface x_x
159
  super(evaluate.Metric, self).add(
@@ -164,12 +166,30 @@ class PQMetric(evaluate.Metric):
164
 
165
  def _compute(self, *, predictions, references, **kwargs):
166
  """Called within the evaluate.Metric.compute() method"""
167
- result = self.pq_metric.compute() # n_classes (sorted things + sorted stuffs), (pq, sq, rq)
 
 
 
 
168
  id2label = {id: label for label, id in self.label2id.items()}
169
- return {
170
- id2label[numeric_label]: result[i] for i, numeric_label in enumerate(self.pq_metric.things + self.pq_metric.stuffs)
 
 
 
 
 
 
 
171
  }
172
 
 
 
 
 
 
 
 
173
  def add_payload(self, payload: Payload, model_name: str = None):
174
  """Converts the payload to the format expected by the metric"""
175
  # import only if needed since fiftyone is not a direct dependency
 
71
  Added data ...
72
  Start computing ...
73
  Finished!
 
74
  """
75
 
76
 
 
80
  self,
81
  label2id: Dict[str, int] = None,
82
  stuff: List[str] = None,
83
+ per_class: bool = True,
84
+ split_sq_rq: bool = True,
85
  **kwargs
86
  ):
87
  super().__init__(**kwargs)
 
110
 
111
  self.label2id = label2id if label2id is not None else DEFAULT_LABEL2ID
112
  self.stuff = stuff if stuff is not None else DEFAULT_STUFF
113
+ self.per_class = per_class
114
+ self.split_sq_rq = split_sq_rq
115
  self.pq_metric = PanopticQuality(
116
  things=set([self.label2id[label] for label in self.label2id.keys() if label not in self.stuff]),
117
+ stuffs=set([self.label2id[label] for label in self.label2id.keys() if label in self.stuff]),
118
+ return_per_class=per_class,
119
+ return_sq_and_rq=split_sq_rq
120
  )
121
 
122
  def _info(self):
 
156
  # in case the inputs are lists, convert them to numpy arrays
157
 
158
  self.pq_metric.update(prediction, reference)
 
 
 
159
 
160
  # does not impact the metric, but is required for the interface x_x
161
  super(evaluate.Metric, self).add(
 
166
 
167
  def _compute(self, *, predictions, references, **kwargs):
168
  """Called within the evaluate.Metric.compute() method"""
169
+ tp = self.pq_metric.metric.true_positives.clone()
170
+ fp = self.pq_metric.metric.false_positives.clone()
171
+ fn = self.pq_metric.metric.false_negatives.clone()
172
+ iou = self.pq_metric.metric.iou_sum.clone()
173
+
174
  id2label = {id: label for label, id in self.label2id.items()}
175
+ things_stuffs = sorted(self.pq_metric.things) + sorted(self.pq_metric.stuffs)
176
+
177
+ # compute scores
178
+ result = self.pq_metric.compute() # shape : (n_classes (sorted things + sorted stuffs), scores (pq, sq, rq))
179
+
180
+ result_dict = {
181
+ "numbers": {id2label[numeric_label]: [tp[i].item(), fp[i].item(), fn[i].item(), iou[i].item()] \
182
+ for i, numeric_label in enumerate(things_stuffs)},
183
+ "scores": None
184
  }
185
 
186
+ if self.per_class:
187
+ result_dict["scores"] = {id2label[numeric_label]: result[i].tolist() for i, numeric_label in enumerate(things_stuffs)}
188
+ else:
189
+ result_dict["scores"] = result.tolist()
190
+
191
+ return result_dict
192
+
193
  def add_payload(self, payload: Payload, model_name: str = None):
194
  """Converts the payload to the format expected by the metric"""
195
  # import only if needed since fiftyone is not a direct dependency