# panoptic-quality / PanopticQuality.py
# Last commit 6790ab3 by franzi2505: "add option to change method" (11.2 kB)
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Panoptic Quality (PQ) metric for the HuggingFace `evaluate` library, computed with seametrics."""
from typing import Dict, List, Literal, Optional, Tuple

import datasets
import evaluate
import numpy as np
from seametrics.panoptic import PanopticQuality
from seametrics.payload import Payload
# BibTeX entry (from DBLP). A raw string keeps LaTeX escapes such as \' and \_
# verbatim; in a normal string "\'" silently became "'" and "\_" is an
# invalid escape sequence.
_CITATION = r"""@inproceedings{DBLP:conf/cvpr/KirillovHGRD19,
author = {Alexander Kirillov and
Kaiming He and
Ross B. Girshick and
Carsten Rother and
Piotr Doll{\'{a}}r},
title = {Panoptic Segmentation},
booktitle = {{IEEE} Conference on Computer Vision and Pattern Recognition, {CVPR}
2019, Long Beach, CA, USA, June 16-20, 2019},
pages = {9404--9413},
publisher = {Computer Vision Foundation / {IEEE}},
year = {2019},
url = {http://openaccess.thecvf.com/content\_CVPR\_2019/html/Kirillov\_Panoptic\_Segmentation\_CVPR\_2019\_paper.html}
}
"""
# Short description shown on the module page.
_DESCRIPTION = """\
This evaluation metric calculates Panoptic Quality (PQ) for panoptic segmentation masks.
"""
# Input/output documentation; also prepended to the class docstring by
# `add_start_docstrings`.
_KWARGS_DESCRIPTION = """
Calculates PQ-score given predicted and ground truth panoptic segmentation masks.
Args:
predictions: a 4-d array of shape (batch_size, img_height, img_width, 2).
The last dimension should hold the category index at position 0, and
the instance ID at position 1.
references: a 4-d array of shape (batch_size, img_height, img_width, 2).
The last dimension should hold the category index at position 0, and
the instance ID at position 1.
Returns:
A dictionary with the keys "scores" and "numbers". "scores" holds the PQ
score (plus SQ and RQ when split_sq_rq=True); "numbers" holds the TP, FP
and FN counts and the IoU sum, stacked in that order. With per_class=True
both are keyed by class name plus "ALL"; otherwise only "ALL" is present.
Examples:
>>> import evaluate
>>> from seametrics.payload.processor import PayloadProcessor
>>> MODEL_FIELD = ["maskformer-27k-100ep"]
>>> payload = PayloadProcessor("SAILING_PANOPTIC_DATASET_QA",
...                            gt_field="ground_truth_det",
...                            models=MODEL_FIELD,
...                            sequence_list=["Trip_55_Seq_2", "Trip_197_Seq_1", "Trip_197_Seq_68"],
...                            excluded_classes=[""]).payload
>>> module = evaluate.load("SEA-AI/PanopticQuality")
>>> module.add_payload(payload, model_name=MODEL_FIELD[0])
>>> module.compute()
100%|██████████| 3/3 [00:03<00:00, 1.30s/it]
Added data ...
Start computing ...
Finished!
"""
@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class PQMetric(evaluate.Metric):
    # Panoptic Quality metric exposed through the HuggingFace `evaluate` API.
    #
    # Accumulation and scoring are delegated to
    # `seametrics.panoptic.PanopticQuality` (stored as `self.pq_metric`); this
    # class adapts it to the `evaluate.Metric` interface and converts its
    # tensors into dictionaries keyed by class name. The class docstring is
    # generated by the decorator from _DESCRIPTION and _KWARGS_DESCRIPTION,
    # which is why the class itself is documented with comments.

    # Default category mapping: one id per class.
    _DEFAULT_LABEL2ID = {
        'WATER': 0,
        'SKY': 1,
        'LAND': 2,
        'MOTORBOAT': 3,
        'FAR_AWAY_OBJECT': 4,
        'SAILING_BOAT_WITH_CLOSED_SAILS': 5,
        'SHIP': 6,
        'WATERCRAFT': 7,
        'SPHERICAL_BUOY': 8,
        'CONSTRUCTION': 9,
        'FLOTSAM': 10,
        'SAILING_BOAT_WITH_OPEN_SAILS': 11,
        'CONTAINER': 12,
        'PILLAR_BUOY': 13,
        'AERIAL_ANIMAL': 14,
        'HUMAN_IN_WATER': 15,
        'OWN_BOAT': 16,
        'WOODEN_LOG': 17,
        'MARITIME_ANIMAL': 18,
    }

    # Class-agnostic mapping: WATER, SKY, LAND, CONSTRUCTION and OWN_BOAT keep
    # their own ids; every other class is merged into id 3.
    _CLASS_AGNOSTIC_LABEL2ID = {
        'WATER': 0,
        'SKY': 1,
        'LAND': 2,
        'MOTORBOAT': 3,
        'FAR_AWAY_OBJECT': 3,
        'SAILING_BOAT_WITH_CLOSED_SAILS': 3,
        'SHIP': 3,
        'WATERCRAFT': 3,
        'SPHERICAL_BUOY': 3,
        'CONSTRUCTION': 4,
        'FLOTSAM': 3,
        'SAILING_BOAT_WITH_OPEN_SAILS': 3,
        'CONTAINER': 3,
        'PILLAR_BUOY': 3,
        'AERIAL_ANIMAL': 3,
        'HUMAN_IN_WATER': 3,
        'OWN_BOAT': 5,
        'WOODEN_LOG': 3,
        'MARITIME_ANIMAL': 3,
    }

    # Class names treated as "stuff" by default; all other names are "things".
    # ICE has no id in the default mappings and only matters when a custom
    # label2id contains it.
    _DEFAULT_STUFF = ("WATER", "SKY", "LAND", "CONSTRUCTION", "ICE", "OWN_BOAT")

    # Default (min_area, max_area) ranges. (0, 1e5**2) acts as an "all" range;
    # the remaining ranges split at 6**2 and 12**2. A tuple so the shared
    # default can never be mutated between instances.
    _DEFAULT_AREA_RNG = (
        (0, 1e5**2),
        (0**2, 6**2),
        (6**2, 12**2),
        (12**2, 1e5**2),
    )

    def __init__(
        self,
        label2id: Optional[Dict[str, int]] = None,
        stuff: Optional[List[str]] = None,
        per_class: bool = True,
        split_sq_rq: bool = True,
        area_rng: Optional[List[Tuple[float, float]]] = None,
        method: Literal["iou", "hungarian"] = "hungarian",
        device: Optional[str] = None,
        class_agnostic: bool = False,
        **kwargs,
    ):
        """Configure the wrapped seametrics ``PanopticQuality`` metric.

        Args:
            label2id: Mapping from class name to category id. Defaults to the
                built-in mapping (the class-agnostic one if ``class_agnostic``).
            stuff: Class names treated as "stuff"; every other name in
                ``label2id`` is a "thing". Defaults to ``_DEFAULT_STUFF``.
            per_class: Also report scores and counts per class (not only "ALL").
            split_sq_rq: Forwarded to seametrics as ``return_sq_and_rq``.
            area_rng: List of (min_area, max_area) pairs forwarded as
                ``areas``. ``None`` (the default) uses ``_DEFAULT_AREA_RNG``.
            method: Matching method forwarded to seametrics.
            device: Device forwarded to seametrics.
            class_agnostic: Use the class-agnostic default mapping. Ignored
                when ``label2id`` is given.
            **kwargs: Forwarded to ``evaluate.Metric.__init__``.
        """
        super().__init__(**kwargs)
        if label2id is None:
            default = self._CLASS_AGNOSTIC_LABEL2ID if class_agnostic else self._DEFAULT_LABEL2ID
            # Copy so per-instance edits never leak into the class default.
            label2id = dict(default)
        self.label2id = label2id
        # NOTE(review): when several names share an id (class-agnostic mode)
        # the last name wins, e.g. id 3 is reported as 'MARITIME_ANIMAL'.
        self.id2label = {cat_id: name for name, cat_id in self.label2id.items()}
        self.stuff = list(self._DEFAULT_STUFF) if stuff is None else stuff
        self.per_class = per_class
        self.split_sq_rq = split_sq_rq
        if area_rng is None:
            # Fresh list per instance (the old default was a shared mutable list).
            area_rng = list(self._DEFAULT_AREA_RNG)

        stuff_names = set(self.stuff)
        self.pq_metric = PanopticQuality(
            things={cat_id for name, cat_id in self.label2id.items() if name not in stuff_names},
            stuffs={cat_id for name, cat_id in self.label2id.items() if name in stuff_names},
            return_per_class=per_class,
            return_sq_and_rq=split_sq_rq,
            areas=area_rng,
            device=device,
            method=method,
        )
        # Invert seametrics' category-id -> continuous-index mapping so result
        # columns (continuous indices) can be mapped back to category ids.
        self.cont_to_cat = {
            cont: cat for cat, cont in self.pq_metric.metric.cat_id_to_continuous_id.items()
        }

    def _info(self):
        """Return the ``evaluate.MetricInfo`` describing this metric's inputs."""

        def mask_batch_feature():
            # batch -> image height -> image width -> (category id, instance id)
            return datasets.Sequence(
                datasets.Sequence(
                    datasets.Sequence(
                        datasets.Sequence(datasets.Value("float"))
                    )
                )
            )

        return evaluate.MetricInfo(
            # This is the description that will appear on the modules page.
            module_type="metric",
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(
                {
                    "predictions": mask_batch_feature(),
                    "references": mask_batch_feature(),
                }
            ),
            # Additional links to the codebase or references
            codebase_urls=[
                "https://lightning.ai/docs/torchmetrics/stable/detection/panoptic_quality.html"
            ],
        )

    def add(self, *, prediction, reference, **kwargs):
        """Add a batch of predictions and references to the metric.

        The batch is passed unchanged to the wrapped seametrics metric, which
        holds the real state. The parent ``add`` is only called to satisfy the
        ``evaluate`` interface and receives tiny placeholders instead.
        """
        self.pq_metric.update(prediction, reference)
        # Bypass this class's own `add` and register placeholders with the
        # base evaluation module; `reference=` is its actual keyword.
        super(evaluate.Metric, self).add(
            prediction=self._postprocess(prediction),
            reference=self._postprocess(reference),
            **kwargs,
        )

    def _compute(self, *, predictions, references, **kwargs):
        """Compute scores from the state accumulated by ``add``.

        ``predictions`` and ``references`` are the placeholders stored by the
        parent module and are ignored.

        Returns:
            Dict with "scores" and "numbers". With ``per_class`` both are keyed
            by class name plus "ALL"; otherwise only "ALL" is present. Each
            "numbers" entry stacks [TP, FP, FN, IoU sum] along axis 0.
        """
        metric = self.pq_metric.metric
        # shape: (area_rngs, n_classes (sorted things + sorted stuffs))
        counts = (
            metric.true_positives.clone().cpu(),
            metric.false_positives.clone().cpu(),
            metric.false_negatives.clone().cpu(),
            metric.iou_sum.clone().cpu(),
        )
        # NOTE(review): the per-class indexing below uses the LAST axis as the
        # class axis, so the per-class result is presumably
        # (scores, area_rngs, n_classes) -- confirm against seametrics.
        result = self.pq_metric.compute().cpu()

        result_dict = {}
        if self.per_class:
            if not self.split_sq_rq:
                # Add a leading score axis so indexing matches the split case.
                result = result.unsqueeze(0)
            scores = {
                self.id2label[cat]: result[:, :, cont].numpy()
                for cont, cat in self.cont_to_cat.items()
            }
            scores["ALL"] = result.mean(dim=-1).numpy()
            numbers = {
                self.id2label[cat]: np.stack([c[:, cont].numpy() for c in counts])
                for cont, cat in self.cont_to_cat.items()
            }
            numbers["ALL"] = np.stack([c.sum(dim=1).numpy() for c in counts])
        else:
            all_scores = result.numpy()
            if not self.split_sq_rq:
                # Pad with leading axes so the rank matches the split output.
                if len(self.pq_metric.get_areas()) > 1:
                    all_scores = all_scores[np.newaxis, ...]
                else:
                    all_scores = all_scores[np.newaxis, np.newaxis, ...]
            scores = {"ALL": all_scores}
            numbers = {"ALL": np.stack([c.sum(dim=-1).numpy() for c in counts])}
        result_dict["scores"] = scores
        result_dict["numbers"] = numbers
        return result_dict

    def add_payload(self, payload: Payload, model_name: Optional[str] = None):
        """Convert a seametrics ``Payload`` and add it to the metric.

        Args:
            payload: Payload holding ground truth and model predictions.
            model_name: Name of the model whose predictions are evaluated.
        """
        # import only if needed since fiftyone is not a direct dependency
        from seametrics.panoptic.utils import payload_to_seg_metric

        predictions, references, label2id = payload_to_seg_metric(payload, model_name, self.label2id)
        self.label2id = label2id
        # Keep the reverse mapping in sync so _compute names classes with the
        # mapping the data was encoded with; ids absent from the new mapping
        # keep their previous names.
        self.id2label = {
            **self.id2label,
            **{cat_id: name for name, cat_id in label2id.items()},
        }
        self.add(prediction=predictions, reference=references)

    def _postprocess(self, np_array):
        """Return a tiny placeholder in place of ``np_array``.

        The real data is held by ``self.pq_metric``; the copy registered with
        the parent module is only needed for the interface, so a 1x1x1x1
        nested list of zeros is used to avoid out-of-memory problems.
        ``np_array`` is intentionally ignored.
        """
        return np.zeros((1, 1, 1, 1)).tolist()