|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Implementation of the Panoptic Quality metric. |
|
|
|
Panoptic Quality is an instance-based metric for evaluating the task of |
|
image parsing, aka panoptic segmentation. |
|
|
|
Please see the paper for details: |
|
"Panoptic Segmentation", Alexander Kirillov, Kaiming He, Ross Girshick, |
|
Carsten Rother and Piotr Dollar. arXiv:1801.00868, 2018. |
|
""" |
|
|
|
from __future__ import absolute_import |
|
from __future__ import division |
|
from __future__ import print_function |
|
|
|
import collections |
|
import numpy as np |
|
import prettytable |
|
import six |
|
|
|
from deeplab.evaluation import base_metric |
|
|
|
|
|
def _ids_to_counts(id_array): |
|
"""Given a numpy array, a mapping from each unique entry to its count.""" |
|
ids, counts = np.unique(id_array, return_counts=True) |
|
return dict(six.moves.zip(ids, counts)) |
|
|
|
|
|
class PanopticQuality(base_metric.SegmentationMetric):
  """Metric class for Panoptic Quality.

  "Panoptic Segmentation" by Alexander Kirillov, Kaiming He, Ross Girshick,
  Carsten Rother, Piotr Dollar.
  https://arxiv.org/abs/1801.00868
  """

  def compare_and_accumulate(
      self, groundtruth_category_array, groundtruth_instance_array,
      predicted_category_array, predicted_instance_array):
    """Compares a single sample's prediction to its ground truth and accumulates.

    Matches predicted segments to groundtruth segments of the same category
    when their IoU exceeds 0.5 (the matching rule from the Panoptic
    Segmentation paper), then updates the per-category true-positive,
    false-negative, false-positive, and IoU accumulators.

    Args:
      groundtruth_category_array: Integer array of per-pixel groundtruth
        category labels.
      groundtruth_instance_array: Integer array of per-pixel groundtruth
        instance ids; same shape as the category array.
      predicted_category_array: Integer array of per-pixel predicted category
        labels; same shape as the groundtruth arrays.
      predicted_instance_array: Integer array of per-pixel predicted instance
        ids; same shape as the groundtruth arrays.

    Returns:
      The mean Panoptic Quality over valid categories, as computed by
      `result()`, over all samples accumulated so far.
    """
    # Fold each (category, instance) pair into one per-pixel segment id.
    # NOTE(review): the `//` and `%` decoding below assumes the base class
    # encodes segment ids as category * max_instances_per_category + instance
    # — confirm against base_metric._naively_combine_labels.
    pred_segment_id = self._naively_combine_labels(predicted_category_array,
                                                   predicted_instance_array)
    gt_segment_id = self._naively_combine_labels(groundtruth_category_array,
                                                 groundtruth_instance_array)

    # Pre-compute the pixel area of every groundtruth and predicted segment.
    gt_segment_areas = _ids_to_counts(gt_segment_id)
    pred_segment_areas = _ids_to_counts(pred_segment_id)

    # The "void" segment: ignored category with instance id 0.
    void_segment_id = self.ignored_label * self.max_instances_per_category

    # All groundtruth segments whose category is the ignored label (the void
    # segment plus any ignored segments with instance id > 0).
    ignored_segment_ids = {
        gt_segment_id for gt_segment_id in six.iterkeys(gt_segment_areas)
        if (gt_segment_id //
            self.max_instances_per_category) == self.ignored_label
    }

    # Combine groundtruth and predicted segment ids into one per-pixel id so
    # that each (groundtruth segment, predicted segment) pair gets a unique
    # integer, encoded as gt_segment_id * offset + pred_segment_id.
    # NOTE(review): this assumes gt_segment_id * offset + pred_segment_id
    # fits in uint32, i.e. offset * (max segment id) < 2**32 — confirm for
    # large category/instance spaces.
    intersection_id_array = (
        gt_segment_id.astype(np.uint32) * self.offset +
        pred_segment_id.astype(np.uint32))

    # Pixel area of every non-empty (groundtruth, predicted) intersection.
    intersection_areas = _ids_to_counts(intersection_id_array)

    def prediction_void_overlap(pred_segment_id):
      # Area of overlap between a predicted segment and the void segment.
      void_intersection_id = void_segment_id * self.offset + pred_segment_id
      return intersection_areas.get(void_intersection_id, 0)

    def prediction_ignored_overlap(pred_segment_id):
      # Total area of overlap between a predicted segment and *all* ignored
      # groundtruth segments (void and ignored crowd/instance segments).
      total_ignored_overlap = 0
      for ignored_segment_id in ignored_segment_ids:
        intersection_id = ignored_segment_id * self.offset + pred_segment_id
        total_ignored_overlap += intersection_areas.get(intersection_id, 0)
      return total_ignored_overlap

    # Segment ids that get matched (IoU > 0.5) in the loop below.
    gt_matched = set()
    pred_matched = set()

    # Match intersecting segments of the same category by IoU.
    # NOTE: the loop variables below shadow the segment-id arrays of the same
    # names; the arrays are no longer needed at this point.
    for intersection_id, intersection_area in six.iteritems(intersection_areas):
      gt_segment_id = intersection_id // self.offset
      pred_segment_id = intersection_id % self.offset

      gt_category = gt_segment_id // self.max_instances_per_category
      pred_category = pred_segment_id // self.max_instances_per_category
      if gt_category != pred_category:
        continue

      # The union excludes the part of the predicted segment that overlaps
      # the groundtruth void segment, so void pixels don't penalize the IoU.
      union = (
          gt_segment_areas[gt_segment_id] +
          pred_segment_areas[pred_segment_id] - intersection_area -
          prediction_void_overlap(pred_segment_id))
      iou = intersection_area / union
      # IoU > 0.5 guarantees the match is unique (paper, Sec. 4).
      if iou > 0.5:
        self.tp_per_class[gt_category] += 1
        self.iou_per_class[gt_category] += iou
        gt_matched.add(gt_segment_id)
        pred_matched.add(pred_segment_id)

    # False negatives: unmatched groundtruth segments, except ignored ones.
    for gt_segment_id in six.iterkeys(gt_segment_areas):
      if gt_segment_id in gt_matched:
        continue
      category = gt_segment_id // self.max_instances_per_category
      # Failing to detect a segment of the ignored category is not an error.
      if category == self.ignored_label:
        continue
      self.fn_per_class[category] += 1

    # False positives: unmatched predicted segments, except those mostly
    # (> 50%) covered by ignored groundtruth pixels.
    for pred_segment_id in six.iterkeys(pred_segment_areas):
      if pred_segment_id in pred_matched:
        continue
      if (prediction_ignored_overlap(pred_segment_id) /
          pred_segment_areas[pred_segment_id]) > 0.5:
        continue
      category = pred_segment_id // self.max_instances_per_category
      self.fp_per_class[category] += 1

    return self.result()

  def _valid_categories(self):
    """Categories with a "valid" value for the metric, have > 0 instances.

    We will ignore the `ignore_label` class and other classes which have
    `tp + fn + fp = 0`.

    Returns:
      Boolean array of shape `[num_categories]`.
    """
    valid_categories = np.not_equal(
        self.tp_per_class + self.fn_per_class + self.fp_per_class, 0)
    if self.ignored_label >= 0 and self.ignored_label < self.num_categories:
      valid_categories[self.ignored_label] = False
    return valid_categories

  def detailed_results(self, is_thing=None):
    """Computes PQ, SQ and RQ averaged over category subsets.

    Args:
      is_thing: Optional boolean array of shape `[num_categories]`, where True
        marks "thing" categories and False marks "stuff". When omitted, only
        the 'All' subset is reported.

    Returns:
      Dict mapping each subset name ('All', and when `is_thing` is given also
      'Things' and 'Stuff') to a dict with keys 'pq', 'sq', 'rq' (means over
      the subset's valid categories, in [0, 1]) and 'n' (number of valid
      categories in the subset). An empty subset reports all zeros.
    """
    valid_categories = self._valid_categories()

    # Each subset is a boolean mask over categories; 'Things'/'Stuff' further
    # intersect the valid mask with `is_thing` / its complement.
    category_sets = collections.OrderedDict()
    category_sets['All'] = valid_categories
    if is_thing is not None:
      category_sets['Things'] = np.logical_and(valid_categories, is_thing)
      category_sets['Stuff'] = np.logical_and(valid_categories,
                                              np.logical_not(is_thing))

    # Per-category SQ (mean IoU of matches), RQ (F1-style detection score),
    # and PQ = SQ * RQ. Divisions guard against 0/0 via realdiv_maybe_zero.
    sq = base_metric.realdiv_maybe_zero(self.iou_per_class, self.tp_per_class)
    rq = base_metric.realdiv_maybe_zero(
        self.tp_per_class,
        self.tp_per_class + 0.5 * self.fn_per_class + 0.5 * self.fp_per_class)
    pq = np.multiply(sq, rq)

    results = {}
    for category_set_name, in_category_set in six.iteritems(category_sets):
      if np.any(in_category_set):
        results[category_set_name] = {
            'pq': np.mean(pq[in_category_set]),
            'sq': np.mean(sq[in_category_set]),
            'rq': np.mean(rq[in_category_set]),
            # The number of categories in this subset.
            'n': np.sum(in_category_set.astype(np.int32)),
        }
      else:
        results[category_set_name] = {'pq': 0, 'sq': 0, 'rq': 0, 'n': 0}

    return results

  def result_per_category(self):
    """Returns an array of shape `[num_categories]` with PQ per category."""
    sq = base_metric.realdiv_maybe_zero(self.iou_per_class, self.tp_per_class)
    rq = base_metric.realdiv_maybe_zero(
        self.tp_per_class,
        self.tp_per_class + 0.5 * self.fn_per_class + 0.5 * self.fp_per_class)
    return np.multiply(sq, rq)

  def print_detailed_results(self, is_thing=None, print_digits=3):
    """Prints the detailed results as a table of percentages.

    Args:
      is_thing: Optional boolean array of shape `[num_categories]` marking
        "thing" categories, as in `detailed_results`.
      print_digits: Number of decimal places in the printed percentages.
    """
    results = self.detailed_results(is_thing=is_thing)

    tab = prettytable.PrettyTable()

    tab.add_column('', [], align='l')
    for fieldname in ['PQ', 'SQ', 'RQ', 'N']:
      tab.add_column(fieldname, [], align='r')

    for category_set, subset_results in six.iteritems(results):
      data_cols = [
          # Scale to a percentage *before* rounding: rounding first and then
          # multiplying both reintroduces float noise after the round (e.g.
          # round(0.123, 3) * 100 == 12.299999999999999) and leaves only
          # print_digits - 2 decimals in the displayed percent.
          round(subset_results[col_key] * 100, print_digits)
          for col_key in ['pq', 'sq', 'rq']
      ]
      data_cols += [subset_results['n']]
      tab.add_row([category_set] + data_cols)

    print(tab)

  def result(self):
    """Returns the mean PQ over valid categories, or 0 if none are valid."""
    pq_per_class = self.result_per_category()
    valid_categories = self._valid_categories()
    if not np.any(valid_categories):
      return 0.
    return np.mean(pq_per_class[valid_categories])

  def merge(self, other_instance):
    """Adds the accumulated counters of `other_instance` into this metric."""
    self.iou_per_class += other_instance.iou_per_class
    self.tp_per_class += other_instance.tp_per_class
    self.fn_per_class += other_instance.fn_per_class
    self.fp_per_class += other_instance.fp_per_class

  def reset(self):
    """Resets all per-category accumulators to zero."""
    self.iou_per_class = np.zeros(self.num_categories, dtype=np.float64)
    self.tp_per_class = np.zeros(self.num_categories, dtype=np.float64)
    self.fn_per_class = np.zeros(self.num_categories, dtype=np.float64)
    self.fp_per_class = np.zeros(self.num_categories, dtype=np.float64)
|
|