deanna-emery's picture
updates
93528c6
raw
history blame
5.85 kB
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Region Similarity Calculators."""
import tensorflow as tf, tf_keras
def area(box):
  """Returns the area of every box in a (possibly batched) collection.

  Boxes are stored as [y_min, x_min, y_max, x_max] along the last axis. No
  clamping is applied, so a box whose max coordinate is below its min
  coordinate yields a non-positive area.

  B: batch_size
  N: number of boxes

  Args:
    box: a float Tensor with [N, 4], or [B, N, 4].

  Returns:
    a float Tensor with [N], or [B, N]
  """
  with tf.name_scope('Area'):
    # Each coordinate keeps a trailing singleton axis: [N, 1] or [B, N, 1].
    ymin, xmin, ymax, xmax = tf.split(value=box, num_or_size_splits=4, axis=-1)
    height = ymax - ymin
    width = xmax - xmin
    # Drop the trailing singleton axis left over from the split.
    return tf.squeeze(height * width, axis=-1)
def intersection(gt_boxes, boxes):
  """Returns the pairwise overlap areas between two box collections.

  Both inputs store boxes as [y_min, x_min, y_max, x_max] along the last axis.
  Pairs that do not overlap get an intersection of 0.

  B: batch_size
  N: number of groundtruth boxes.
  M: number of anchor boxes.

  Args:
    gt_boxes: a float Tensor with [N, 4], or [B, N, 4]
    boxes: a float Tensor with [M, 4], or [B, M, 4]

  Returns:
    a float Tensor with shape [N, M] or [B, N, M] representing pairwise
    intersections.
  """
  with tf.name_scope('Intersection'):
    gt_ymin, gt_xmin, gt_ymax, gt_xmax = tf.split(
        value=gt_boxes, num_or_size_splits=4, axis=-1)
    ymin, xmin, ymax, xmax = tf.split(
        value=boxes, num_or_size_splits=4, axis=-1)

    # The gt coordinates are columns ([N, 1] / [B, N, 1]). Swapping the last
    # two axes turns each anchor coordinate into a row ([1, M] / [B, 1, M]),
    # so the element-wise ops below broadcast to [N, M] / [B, N, M].
    if len(boxes.shape) == 2:
      swap_last_two = [1, 0]
    else:
      swap_last_two = [0, 2, 1]

    def as_row(coord):
      return tf.transpose(coord, swap_last_two)

    overlap_bottom = tf.minimum(gt_ymax, as_row(ymax))
    overlap_top = tf.maximum(gt_ymin, as_row(ymin))
    overlap_right = tf.minimum(gt_xmax, as_row(xmax))
    overlap_left = tf.maximum(gt_xmin, as_row(xmin))

    raw_heights = overlap_bottom - overlap_top
    raw_widths = overlap_right - overlap_left

    # Negative extents mean the pair does not overlap on that axis.
    zero = tf.cast(0, raw_heights.dtype)
    heights = tf.maximum(zero, raw_heights)
    widths = tf.maximum(zero, raw_widths)
    return heights * widths
def iou(gt_boxes, boxes):
  """Computes pairwise intersection-over-union between box collections.

  B: batch_size
  N: number of groundtruth boxes.
  M: number of anchor boxes.

  Args:
    gt_boxes: a float Tensor with [N, 4], or [B, N, 4].
    boxes: a float Tensor with [M, 4], or [B, M, 4].

  Returns:
    a Tensor with shape [N, M] or [B, N, M] representing pairwise iou scores.
    Pairs whose intersection is exactly 0 get an IoU of 0.
  """
  with tf.name_scope('IOU'):
    intersections = intersection(gt_boxes, boxes)
    gt_boxes_areas = area(gt_boxes)
    boxes_areas = area(boxes)
    boxes_rank = len(boxes_areas.shape)
    boxes_axis = 1 if (boxes_rank == 2) else 0
    # [N, 1] or [B, N, 1]
    gt_boxes_areas = tf.expand_dims(gt_boxes_areas, -1)
    # [1, M] or [B, 1, M]
    boxes_areas = tf.expand_dims(boxes_areas, boxes_axis)
    unions = gt_boxes_areas + boxes_areas - intersections

    no_overlap = tf.equal(intersections, 0.0)
    # When the union is 0 (e.g. two zero-area boxes), a plain division yields
    # 0/0 = NaN. The outer tf.where hides that value in the forward pass, but
    # its gradient still propagates NaN. Swap in a harmless denominator
    # wherever the result will be discarded anyway. Forward outputs are
    # unchanged.
    safe_unions = tf.where(no_overlap, tf.ones_like(unions), unions)
    return tf.where(
        no_overlap, tf.zeros_like(intersections),
        tf.truediv(intersections, safe_unions))
class IouSimilarity:
  """Class to compute similarity based on Intersection over Union (IOU) metric.

  Optionally, pairs involving masked boxes are overwritten with `mask_val`.
  """

  def __init__(self, mask_val=-1):
    """Initializes the calculator.

    Args:
      mask_val: value written into the output wherever a pair is masked.
    """
    self.mask_val = mask_val

  def __call__(self, boxes_1, boxes_2, boxes_1_masks=None, boxes_2_masks=None):
    """Compute pairwise IOU similarity between two box collections.

    The error messages refer to `boxes_1` as `groundtruth_boxes` and to
    `boxes_2` as `anchors`.

    B: batch_size
    N: Number of boxes in `boxes_1`.
    M: Number of boxes in `boxes_2`.

    Args:
      boxes_1: a float Tensor with N or B * N boxes.
      boxes_2: a float Tensor with M or B * M boxes, the rank must be less than
        or equal to rank of `boxes_1`.
      boxes_1_masks: a boolean Tensor with N or B * N entries. Optional. A True
        entry replaces that box's whole output row with `mask_val`.
      boxes_2_masks: a boolean Tensor with M or B * M entries. Optional. A True
        entry replaces that box's whole output column with `mask_val`.

    Returns:
      A Tensor with shape [N, M] or [B, N, M] representing pairwise
      iou scores, one row per box in `boxes_1` and one column per box in
      `boxes_2`. Masked pairs hold `mask_val`.

    Input shape:
      boxes_1: [N, 4], or [B, N, 4]
      boxes_2: [M, 4], or [B, M, 4]
      boxes_1_masks: [N, 1], or [B, N, 1]
      boxes_2_masks: [M, 1], or [B, M, 1]

    Output shape:
      [N, M], or [B, N, M]

    Raises:
      ValueError: if either box tensor is not rank 2 or 3, or if `boxes_1` is
        unbatched while `boxes_2` is batched.
    """
    boxes_1 = tf.cast(boxes_1, tf.float32)
    boxes_2 = tf.cast(boxes_2, tf.float32)

    boxes_1_rank = len(boxes_1.shape)
    boxes_2_rank = len(boxes_2.shape)
    if boxes_1_rank < 2 or boxes_1_rank > 3:
      raise ValueError(
          '`groundtruth_boxes` must be rank 2 or 3, got {}'.format(
              boxes_1_rank))
    if boxes_2_rank < 2 or boxes_2_rank > 3:
      raise ValueError(
          '`anchors` must be rank 2 or 3, got {}'.format(boxes_2_rank))
    if boxes_1_rank < boxes_2_rank:
      raise ValueError('`groundtruth_boxes` is unbatched while `anchors` is '
                       'batched is not a valid use case, got groundtruth_box '
                       'rank {}, and anchors rank {}'.format(
                           boxes_1_rank, boxes_2_rank))

    result = iou(boxes_1, boxes_2)
    if boxes_1_masks is None and boxes_2_masks is None:
      return result

    mask_val_t = tf.cast(self.mask_val, result.dtype) * tf.ones_like(result)
    # Swapping the last two axes turns boxes_2_masks ([M, 1] / [B, M, 1]) into
    # a row mask ([1, M] / [B, 1, M]) that broadcasts across the N rows.
    perm = [1, 0] if boxes_2_rank == 2 else [0, 2, 1]
    if boxes_1_masks is not None and boxes_2_masks is not None:
      background_mask = tf.logical_or(boxes_1_masks,
                                      tf.transpose(boxes_2_masks, perm))
    elif boxes_1_masks is not None:
      # Column mask ([N, 1] / [B, N, 1]); tf.where broadcasts it over M.
      background_mask = boxes_1_masks
    else:
      # Row mask ([1, M] / [B, 1, M]); tf.where broadcasts it over N. It must
      # not be OR-ed with a [B, M]-shaped tensor: that would broadcast to
      # [B, B, M], which is incompatible with [B, N, M] unless B is 1 or N.
      background_mask = tf.transpose(boxes_2_masks, perm)
    return tf.where(background_mask, mask_val_t, result)