# Lint as: python2, python3
# Copyright 2019 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of the Panoptic Quality metric.
Panoptic Quality is an instance-based metric for evaluating the task of
image parsing, aka panoptic segmentation.
Please see the paper for details:
"Panoptic Segmentation", Alexander Kirillov, Kaiming He, Ross Girshick,
Carsten Rother and Piotr Dollar. arXiv:1801.00868, 2018.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

import numpy as np
import prettytable
import six

from deeplab.evaluation import base_metric


def _ids_to_counts(id_array):
  """Returns a dict mapping each unique entry in `id_array` to its count."""
  ids, counts = np.unique(id_array, return_counts=True)
  return dict(six.moves.zip(ids, counts))
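# For example, _ids_to_counts(np.array([1, 1, 1, 2])) yields {1: 3, 2: 1}
# (keys and values are numpy scalars).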


class PanopticQuality(base_metric.SegmentationMetric):
  """Metric class for Panoptic Quality.

  "Panoptic Segmentation" by Alexander Kirillov, Kaiming He, Ross Girshick,
  Carsten Rother, Piotr Dollar.
  https://arxiv.org/abs/1801.00868
  """

  def compare_and_accumulate(
      self, groundtruth_category_array, groundtruth_instance_array,
      predicted_category_array, predicted_instance_array):
    """See base class."""
    # First, combine the category and instance labels so that every unique
    # value for (category, instance) is assigned a unique integer label.
    pred_segment_id = self._naively_combine_labels(predicted_category_array,
                                                   predicted_instance_array)
    gt_segment_id = self._naively_combine_labels(groundtruth_category_array,
                                                 groundtruth_instance_array)
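    # Illustrative example: the decoding via // and % below implies
    # segment_id = category * max_instances_per_category + instance, so with
    # an assumed max_instances_per_category of 256, category 3 / instance 2
    # becomes segment id 3 * 256 + 2 = 770.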

    # Pre-calculate areas for all groundtruth and predicted segments.
    gt_segment_areas = _ids_to_counts(gt_segment_id)
    pred_segment_areas = _ids_to_counts(pred_segment_id)

    # We assume there is only one void segment and it has instance id = 0.
    void_segment_id = self.ignored_label * self.max_instances_per_category

    # There may be other ignored groundtruth segments with instance id > 0;
    # find those ids using the unique segment ids extracted with the area
    # computation above.
    ignored_segment_ids = {
        gt_segment_id for gt_segment_id in six.iterkeys(gt_segment_areas)
        if (gt_segment_id //
            self.max_instances_per_category) == self.ignored_label
    }

    # Next, combine the groundtruth and predicted labels. Dividing up the
    # pixels based on which groundtruth segment and which predicted segment
    # they belong to, this will assign a different 32-bit integer label to
    # each choice of (groundtruth segment, predicted segment), encoded as
    # gt_segment_id * offset + pred_segment_id.
    intersection_id_array = (
        gt_segment_id.astype(np.uint32) * self.offset +
        pred_segment_id.astype(np.uint32))

    # For every combination of (groundtruth segment, predicted segment) with a
    # non-empty intersection, this counts the number of pixels in that
    # intersection.
    intersection_areas = _ids_to_counts(intersection_id_array)
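    # Illustrative example (offset is a constructor argument; 65536 is only
    # an assumed value): a pixel in groundtruth segment 770 and predicted
    # segment 513 gets intersection id 770 * 65536 + 513, from which
    # divmod(intersection_id, 65536) recovers (770, 513).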

    # Helper function that computes the area of the overlap between a
    # predicted segment and the groundtruth void/ignored segment.
    def prediction_void_overlap(pred_segment_id):
      void_intersection_id = void_segment_id * self.offset + pred_segment_id
      return intersection_areas.get(void_intersection_id, 0)

    # Compute overall ignored overlap.
    def prediction_ignored_overlap(pred_segment_id):
      total_ignored_overlap = 0
      for ignored_segment_id in ignored_segment_ids:
        intersection_id = ignored_segment_id * self.offset + pred_segment_id
        total_ignored_overlap += intersection_areas.get(intersection_id, 0)
      return total_ignored_overlap

    # Sets tracking which groundtruth segments have been matched to an
    # overlapping predicted segment, and vice versa.
    gt_matched = set()
    pred_matched = set()

    # Calculate IoU per pair of intersecting segments of the same category.
    for intersection_id, intersection_area in six.iteritems(intersection_areas):
      gt_segment_id = intersection_id // self.offset
      pred_segment_id = intersection_id % self.offset

      gt_category = gt_segment_id // self.max_instances_per_category
      pred_category = pred_segment_id // self.max_instances_per_category
      if gt_category != pred_category:
        continue

      # Union between the groundtruth and predicted segments being compared
      # does not include the portion of the predicted segment that consists of
      # groundtruth "void" pixels.
      union = (
          gt_segment_areas[gt_segment_id] +
          pred_segment_areas[pred_segment_id] - intersection_area -
          prediction_void_overlap(pred_segment_id))
      iou = intersection_area / union
      if iou > 0.5:
        self.tp_per_class[gt_category] += 1
        self.iou_per_class[gt_category] += iou
        gt_matched.add(gt_segment_id)
        pred_matched.add(pred_segment_id)
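
    # Note: as the paper shows, requiring IoU > 0.5 makes the matching
    # unique: at most one predicted segment can overlap a given groundtruth
    # segment with IoU > 0.5, and vice versa, so no explicit assignment
    # step is needed.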

    # Count false negatives for each category.
    for gt_segment_id in six.iterkeys(gt_segment_areas):
      if gt_segment_id in gt_matched:
        continue
      category = gt_segment_id // self.max_instances_per_category
      # Failing to detect a void segment is not a false negative.
      if category == self.ignored_label:
        continue
      self.fn_per_class[category] += 1

    # Count false positives for each category.
    for pred_segment_id in six.iterkeys(pred_segment_areas):
      if pred_segment_id in pred_matched:
        continue
      # A false positive is not penalized if it is mostly ignored in the
      # groundtruth.
      if (prediction_ignored_overlap(pred_segment_id) /
          pred_segment_areas[pred_segment_id]) > 0.5:
        continue
      category = pred_segment_id // self.max_instances_per_category
      self.fp_per_class[category] += 1

    return self.result()

  def _valid_categories(self):
    """Categories with a "valid" value for the metric, i.e. with > 0 instances.

    We will ignore the `ignored_label` class and other classes which have
    `tp + fn + fp = 0`.

    Returns:
      Boolean array of shape `[num_categories]`.
    """
    valid_categories = np.not_equal(
        self.tp_per_class + self.fn_per_class + self.fp_per_class, 0)
    if 0 <= self.ignored_label < self.num_categories:
      valid_categories[self.ignored_label] = False
    return valid_categories

  def detailed_results(self, is_thing=None):
    """See base class."""
    valid_categories = self._valid_categories()

    # If known, break down which categories are valid _and_ things/stuff.
    category_sets = collections.OrderedDict()
    category_sets['All'] = valid_categories
    if is_thing is not None:
      category_sets['Things'] = np.logical_and(valid_categories, is_thing)
      category_sets['Stuff'] = np.logical_and(valid_categories,
                                              np.logical_not(is_thing))

    # Compute individual per-class metrics that constitute factors of PQ.
    sq = base_metric.realdiv_maybe_zero(self.iou_per_class, self.tp_per_class)
    rq = base_metric.realdiv_maybe_zero(
        self.tp_per_class,
        self.tp_per_class + 0.5 * self.fn_per_class + 0.5 * self.fp_per_class)
    pq = np.multiply(sq, rq)

    # Assemble detailed results dictionary.
    results = {}
    for category_set_name, in_category_set in six.iteritems(category_sets):
      if np.any(in_category_set):
        results[category_set_name] = {
            'pq': np.mean(pq[in_category_set]),
            'sq': np.mean(sq[in_category_set]),
            'rq': np.mean(rq[in_category_set]),
            # The number of categories in this subset.
            'n': np.sum(in_category_set.astype(np.int32)),
        }
      else:
        results[category_set_name] = {'pq': 0, 'sq': 0, 'rq': 0, 'n': 0}

    return results
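
  # The returned dictionary has the following shape (the numbers here are
  # made up purely for illustration):
  #   {'All': {'pq': 0.61, 'sq': 0.80, 'rq': 0.76, 'n': 19},
  #    'Things': {...}, 'Stuff': {...}}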

  def result_per_category(self):
    """See base class."""
    sq = base_metric.realdiv_maybe_zero(self.iou_per_class, self.tp_per_class)
    rq = base_metric.realdiv_maybe_zero(
        self.tp_per_class,
        self.tp_per_class + 0.5 * self.fn_per_class + 0.5 * self.fp_per_class)
    return np.multiply(sq, rq)

  def print_detailed_results(self, is_thing=None, print_digits=3):
    """See base class."""
    results = self.detailed_results(is_thing=is_thing)

    tab = prettytable.PrettyTable()
    tab.add_column('', [], align='l')
    for fieldname in ['PQ', 'SQ', 'RQ', 'N']:
      tab.add_column(fieldname, [], align='r')

    for category_set, subset_results in six.iteritems(results):
      # Round after converting to a percentage, so that `print_digits`
      # controls the displayed precision of the percentage values.
      data_cols = [
          round(subset_results[col_key] * 100, print_digits)
          for col_key in ['pq', 'sq', 'rq']
      ]
      data_cols += [subset_results['n']]
      tab.add_row([category_set] + data_cols)

    print(tab)

  def result(self):
    """See base class."""
    pq_per_class = self.result_per_category()
    valid_categories = self._valid_categories()
    if not np.any(valid_categories):
      return 0.
    return np.mean(pq_per_class[valid_categories])

  def merge(self, other_instance):
    """See base class."""
    self.iou_per_class += other_instance.iou_per_class
    self.tp_per_class += other_instance.tp_per_class
    self.fn_per_class += other_instance.fn_per_class
    self.fp_per_class += other_instance.fp_per_class
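
  # Merging enables sharded evaluation: accumulate each shard of the dataset
  # into its own PanopticQuality instance, then fold the partial counts
  # together, e.g. metric_a.merge(metric_b) leaves metric_a covering both
  # shards.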

  def reset(self):
    """See base class."""
    self.iou_per_class = np.zeros(self.num_categories, dtype=np.float64)
    self.tp_per_class = np.zeros(self.num_categories, dtype=np.float64)
    self.fn_per_class = np.zeros(self.num_categories, dtype=np.float64)
    self.fp_per_class = np.zeros(self.num_categories, dtype=np.float64)
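

# Minimal usage sketch (hypothetical label values; the constructor arguments
# follow base_metric.SegmentationMetric and their values here are
# assumptions):
#
#   import numpy as np
#   from deeplab.evaluation import panoptic_quality
#
#   pq = panoptic_quality.PanopticQuality(
#       num_categories=3, ignored_label=2, max_instances_per_category=16,
#       offset=16 * 16)
#   categories = np.array([[0, 0], [1, 1]], dtype=np.uint16)
#   instances = np.zeros_like(categories)
#   # A perfect prediction yields PQ = 1.0.
#   pq.compare_and_accumulate(categories, instances, categories, instances)
#   print(pq.result())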