Commit de1f95e
Parent(s): 0bb094e
fix bugs and add improvements

PanopticQuality.py CHANGED (+28 -8)
@@ -71,7 +71,6 @@ Examples:
     Added data ...
     Start computing ...
     Finished!
-    tensor(0.2082, dtype=torch.float64)
 """
 
 
@@ -81,6 +80,8 @@ class PQMetric(evaluate.Metric):
         self,
         label2id: Dict[str, int] = None,
         stuff: List[str] = None,
+        per_class: bool = True,
+        split_sq_rq: bool = True,
         **kwargs
     ):
         super().__init__(**kwargs)
@@ -109,9 +110,13 @@ class PQMetric(evaluate.Metric):
 
         self.label2id = label2id if label2id is not None else DEFAULT_LABEL2ID
         self.stuff = stuff if stuff is not None else DEFAULT_STUFF
+        self.per_class = per_class
+        self.split_sq_rq = split_sq_rq
         self.pq_metric = PanopticQuality(
             things=set([self.label2id[label] for label in self.label2id.keys() if label not in self.stuff]),
-            stuffs=set([self.label2id[label] for label in self.label2id.keys() if label in self.stuff])
+            stuffs=set([self.label2id[label] for label in self.label2id.keys() if label in self.stuff]),
+            return_per_class=per_class,
+            return_sq_and_rq=split_sq_rq
         )
 
     def _info(self):
@@ -151,9 +156,6 @@ class PQMetric(evaluate.Metric):
         # in case the inputs are lists, convert them to numpy arrays
 
         self.pq_metric.update(prediction, reference)
-        print("TP:", self.pq_metric.metric.true_positives)
-        print("FP:", self.pq_metric.metric.false_positives)
-        print("FN:", self.pq_metric.metric.false_negatives)
 
         # does not impact the metric, but is required for the interface x_x
         super(evaluate.Metric, self).add(
@@ -164,12 +166,30 @@ class PQMetric(evaluate.Metric):
 
     def _compute(self, *, predictions, references, **kwargs):
         """Called within the evaluate.Metric.compute() method"""
-
+        tp = self.pq_metric.metric.true_positives.clone()
+        fp = self.pq_metric.metric.false_positives.clone()
+        fn = self.pq_metric.metric.false_negatives.clone()
+        iou = self.pq_metric.metric.iou_sum.clone()
+
         id2label = {id: label for label, id in self.label2id.items()}
-
-
+        things_stuffs = sorted(self.pq_metric.things) + sorted(self.pq_metric.stuffs)
+
+        # compute scores
+        result = self.pq_metric.compute()  # shape : (n_classes (sorted things + sorted stuffs), scores (pq, sq, rq))
+
+        result_dict = {
+            "numbers": {id2label[numeric_label]: [tp[i].item(), fp[i].item(), fn[i].item(), iou[i].item()] \
+                        for i, numeric_label in enumerate(things_stuffs)},
+            "scores": None
         }
 
+        if self.per_class:
+            result_dict["scores"] = {id2label[numeric_label]: result[i].tolist() for i, numeric_label in enumerate(things_stuffs)}
+        else:
+            result_dict["scores"] = result.tolist()
+
+        return result_dict
+
     def add_payload(self, payload: Payload, model_name: str = None):
         """Converts the payload to the format expected by the metric"""
         # import only if needed since fiftyone is not a direct dependency
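For context on how the new _compute output hangs together: the raw counts stored under result_dict["numbers"] ([tp, fp, fn, iou_sum] per class) and the scores stored under result_dict["scores"] ([pq, sq, rq] per class when per_class=True and split_sq_rq=True) are linked by the standard panoptic quality decomposition PQ = SQ x RQ, with SQ = iou_sum / TP and RQ = TP / (TP + 0.5*FP + 0.5*FN) (Kirillov et al., 2019). The sketch below is a minimal, self-contained illustration of that relationship only; the class names and counts in it are hypothetical, and it does not call the metric class from this repo.

def pq_sq_rq(tp: int, fp: int, fn: int, iou_sum: float):
    """Standard panoptic quality decomposition (Kirillov et al., 2019)."""
    denom = tp + 0.5 * fp + 0.5 * fn
    sq = iou_sum / tp if tp > 0 else 0.0   # segmentation quality: mean IoU over matched segments
    rq = tp / denom if denom > 0 else 0.0  # recognition quality: F1-style detection term
    return sq * rq, sq, rq                 # pq = sq * rq

# Hypothetical per-class counts, in the same [tp, fp, fn, iou_sum] layout that
# _compute stores under result_dict["numbers"].
numbers = {
    "boat":  [12, 3, 4, 9.4],
    "water": [1, 0, 0, 0.9],
}

# Mirrors the shape of result_dict["scores"] for per_class=True, split_sq_rq=True.
scores = {label: list(pq_sq_rq(*counts)) for label, counts in numbers.items()}
print(scores)  # roughly {'boat': [0.61, 0.78, 0.77], 'water': [0.9, 0.9, 1.0]}

When per_class=False, the diff's _compute falls back to result.tolist(), so "scores" is a plain list for the aggregated result rather than a per-class dict.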