# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Box matcher implementation."""

from typing import List, Tuple

import tensorflow as tf, tf_keras


class BoxMatcher:
  """Matcher based on highest value.

  This class computes matches from a similarity matrix of shape
  [num_rows, num_cols]. Each row is matched to the single column with which
  it has the highest similarity, and that highest similarity is then
  classified into a match type.

  To support object detection target assignment this class enables setting
  both positive_threshold (upper threshold) and negative_threshold (lower
  threshold), defining three categories of similarity which decide whether
  examples are positive, negative, or ignored, for example:

  (1) thresholds=[negative_threshold, positive_threshold], and
      indicators=[negative_value, ignore_value, positive_value]: similarities
      below negative_threshold are assigned negative_value, similarities
      between negative_threshold and positive_threshold are assigned
      ignore_value, and similarities at or above positive_threshold are
      assigned positive_value.
  (2) thresholds=[negative_threshold, positive_threshold], and
      indicators=[ignore_value, negative_value, positive_value]: similarities
      below negative_threshold are assigned ignore_value, similarities
      between negative_threshold and positive_threshold are assigned
      negative_value, and similarities at or above positive_threshold are
      assigned positive_value.

  Buckets are half-open intervals `[low, high)`: a similarity exactly equal to
  a threshold falls into the bucket above that threshold.
  """

  def __init__(self,
               thresholds: List[float],
               indicators: List[int],
               force_match_for_each_col: bool = False):
    """Construct BoxMatcher.

    Args:
      thresholds: A sorted (non-decreasing) list of thresholds used to
        classify the matches into different types (e.g. positive or negative
        or ignored match). The input is copied, never mutated; the stored copy
        is prepended with -Inf and appended with +Inf. Any sequence (e.g. a
        tuple) is accepted.
      indicators: A list of values representing match types (e.g. positive or
        negative or ignored match). len(`indicators`) must equal
        len(`thresholds`) + 1.
      force_match_for_each_col: If True, the best-scoring row of every column
        is forced to match that column and is assigned `indicators[-1]`, even
        if its similarity is below the positive threshold. When several
        columns share the same best row, only one of them keeps it. Defaults
        to False.

    Raises:
      ValueError: If `thresholds` is not sorted, or
        len(indicators) != len(thresholds) + 1.
    """
    # Validate everything before touching `self`.
    if not all(lo <= hi for lo, hi in zip(thresholds[:-1], thresholds[1:])):
      raise ValueError('`threshold` must be sorted, got {}'.format(thresholds))
    if len(indicators) != len(thresholds) + 1:
      raise ValueError('len(`indicators`) must be len(`thresholds`) + 1, got '
                       'indicators {}, thresholds {}'.format(
                           indicators, thresholds))

    self.indicators = indicators
    # Build a fresh list so that non-list sequences work and the caller's
    # object is left untouched. The infinite sentinels turn the N thresholds
    # into N + 1 consecutive [low, high) buckets, one per indicator.
    self.thresholds = [-float('inf')] + list(thresholds) + [float('inf')]
    self._force_match_for_each_col = force_match_for_each_col

  def __call__(self,
               similarity_matrix: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
    """Matches each row of the similarity matrix to its best column.

    Args:
      similarity_matrix: A float tensor of shape [num_rows, num_cols] or
        [batch_size, num_rows, num_cols] representing any similarity metric.

    Returns:
      matched_columns: An integer tensor of shape [num_rows] or
        [batch_size, num_rows] storing the index of the matched column for
        each row.
      match_indicators: An integer tensor of shape [num_rows] or
        [batch_size, num_rows] storing the match type indicator (e.g. positive
        or negative or ignored match).
    """
    # Work on a batched tensor internally; remember whether to undo it.
    squeeze_result = len(similarity_matrix.shape) == 2
    if squeeze_result:
      similarity_matrix = tf.expand_dims(similarity_matrix, axis=0)

    # Prefer static dimensions and fall back to dynamic ones when a dimension
    # is unknown (None) or statically zero.
    static_shape = similarity_matrix.shape.as_list()
    num_rows = static_shape[1] or tf.shape(similarity_matrix)[1]
    batch_size = static_shape[0] or tf.shape(similarity_matrix)[0]

    def _match_when_cols_are_empty():
      """Performs matching when the similarity matrix has no columns.

      With no columns there is nothing a row could be matched to, so every
      row gets column index 0 and the fixed indicator -1 (independent of
      `self.indicators`).

      Returns:
        matched_columns: An int32 tensor of zeros, shape
          [batch_size, num_rows].
        match_indicators: An int32 tensor of -1's, shape
          [batch_size, num_rows].
      """
      with tf.name_scope('empty_gt_boxes'):
        matched_columns = tf.zeros([batch_size, num_rows], dtype=tf.int32)
        match_indicators = -tf.ones([batch_size, num_rows], dtype=tf.int32)
        return matched_columns, match_indicators

    def _match_when_cols_are_non_empty():
      """Performs matching when the similarity matrix has at least one column.

      Returns:
        matched_columns: An int32 tensor of shape [batch_size, num_rows]
          storing the index of the matched column for each row.
        match_indicators: An int32 tensor of shape [batch_size, num_rows]
          storing the match type indicator (e.g. positive or negative or
          ignored match).
      """
      with tf.name_scope('non_empty_gt_boxes'):
        # [batch_size, num_rows], best column per row and its similarity.
        matched_columns = tf.argmax(
            similarity_matrix, axis=-1, output_type=tf.int32)
        matched_vals = tf.reduce_max(similarity_matrix, axis=-1)

        # Classify the best similarity of every row into its [low, high)
        # bucket. NOTE: a best similarity of NaN or +Inf satisfies none of the
        # intervals and therefore keeps the initial value 0.
        match_indicators = tf.zeros([batch_size, num_rows], tf.int32)
        match_dtype = matched_vals.dtype
        buckets = zip(self.indicators, self.thresholds[:-1],
                      self.thresholds[1:])
        for ind, low, high in buckets:
          low_threshold = tf.cast(low, match_dtype)
          high_threshold = tf.cast(high, match_dtype)
          mask = tf.logical_and(
              tf.greater_equal(matched_vals, low_threshold),
              tf.less(matched_vals, high_threshold))
          match_indicators = self._set_values_using_indicator(
              match_indicators, mask, ind)

        if self._force_match_for_each_col:
          # [batch_size, num_cols], for each column (groundtruth_box), find
          # the best matching row (anchor).
          matching_rows = tf.argmax(
              input=similarity_matrix, axis=1, output_type=tf.int32)
          # [batch_size, num_cols, num_rows], a transposed 0-1 mapping matrix
          # M, where M[j, i] = 1 means column j is matched to row i.
          column_to_row_match_mapping = tf.one_hot(
              matching_rows, depth=num_rows)
          # [batch_size, num_rows], for each row (anchor), find the matched
          # column (groundtruth_box).
          force_matched_columns = tf.argmax(
              input=column_to_row_match_mapping, axis=1, output_type=tf.int32)
          # [batch_size, num_rows], True where at least one column picked
          # this row as its best row.
          force_matched_column_mask = tf.cast(
              tf.reduce_max(column_to_row_match_mapping, axis=1), tf.bool)
          # [batch_size, num_rows], forced rows override the per-row result
          # and always receive the last (positive) indicator.
          matched_columns = tf.where(force_matched_column_mask,
                                     force_matched_columns, matched_columns)
          match_indicators = tf.where(
              force_matched_column_mask,
              self.indicators[-1] *
              tf.ones([batch_size, num_rows], dtype=tf.int32),
              match_indicators)

        return matched_columns, match_indicators

    # The last dimension is the number of columns (groundtruth boxes).
    num_gt_boxes = static_shape[-1] or tf.shape(similarity_matrix)[-1]
    matched_columns, match_indicators = tf.cond(
        pred=tf.greater(num_gt_boxes, 0),
        true_fn=_match_when_cols_are_non_empty,
        false_fn=_match_when_cols_are_empty)

    if squeeze_result:
      matched_columns = tf.squeeze(matched_columns, axis=0)
      match_indicators = tf.squeeze(match_indicators, axis=0)

    return matched_columns, match_indicators

  def _set_values_using_indicator(self, x, indicator, val):
    """Set the indicated fields of x to val.

    Implemented arithmetically as `x * (1 - indicator) + val * indicator`
    after casting `indicator` to the dtype of `x`.

    Args:
      x: tensor.
      indicator: boolean with same shape as x.
      val: scalar with value to set.

    Returns:
      modified tensor.
    """
    indicator = tf.cast(indicator, x.dtype)
    return tf.add(tf.multiply(x, 1 - indicator), val * indicator)