# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tensorflow implementation of non max suppression."""

# Import libraries
import tensorflow as tf, tf_keras

from official.vision.ops import box_ops


NMS_TILE_SIZE = 512


def _self_suppression(iou, _, iou_sum):
  """Runs one refinement step of suppression inside a single tile.

  Used as the body of a `tf.while_loop`. An entry `iou[b, i, j]` above 0.5
  is read as "box i currently suppresses box j". Rows of boxes that are
  themselves suppressed by a box nothing suppresses are zeroed, because a
  suppressed box must not suppress anything else.

  The 0.5 cut-offs only separate "zero" from "non-zero" entries, so they are
  correct only if every non-zero entry is larger than 0.5.
  `_suppression_loop_body` guarantees this by passing a 0/1 indicator matrix.

  Args:
    iou: a 3-D tensor indexed as [batch, suppressing box, suppressed box].
    _: unused; placeholder for the loop-continuation flag loop variable.
    iou_sum: per-batch sum of `iou` from the previous iteration.

  Returns:
    A list `[iou_suppressed, changed, iou_sum_new]`: the matrix with the rows
    of suppressed boxes cleared, a boolean scalar that is True when the sum
    dropped by more than 0.5 for any batch element, and the new per-batch sum.
  """
  batch_size = tf.shape(iou)[0]
  # 1.0 for every box that no other box suppresses (its column is all zero),
  # laid out along axis 1 so it can be used as a row mask.
  can_suppress_others = tf.cast(
      tf.reshape(tf.reduce_max(iou, 1) <= 0.5, [batch_size, -1, 1]), iou.dtype)
  # Keep only the rows of boxes that are not suppressed by one of the
  # unsuppressed boxes found above.
  iou_suppressed = tf.reshape(
      tf.cast(tf.reduce_max(can_suppress_others * iou, 1) <= 0.5, iou.dtype),
      [batch_size, -1, 1]) * iou
  iou_sum_new = tf.reduce_sum(iou_suppressed, [1, 2])
  return [
      iou_suppressed,
      tf.reduce_any(iou_sum - iou_sum_new > 0.5), iou_sum_new
  ]


def _cross_suppression(boxes, box_slice, iou_threshold, inner_idx):
  """Suppresses boxes of the current tile using one earlier tile.

  Used as the body of a `tf.while_loop`. Boxes of `box_slice` that overlap
  any box of tile `inner_idx` with an IoU of at least `iou_threshold` are
  multiplied by zero, i.e. cleared to all-zero coordinates.

  Args:
    boxes: a tensor with a shape of [batch_size, anchors, 4].
    box_slice: the tile being processed, sliced from `boxes` with a shape of
      [batch_size, NMS_TILE_SIZE, 4].
    iou_threshold: a float representing the threshold for deciding whether
      boxes overlap too much with respect to IOU.
    inner_idx: an integer scalar, index of the suppressing tile.

  Returns:
    boxes: the unchanged input boxes.
    ret_slice: `box_slice` with the suppressed boxes zeroed out.
    iou_threshold: passed through to the next iteration.
    inner_idx: the incremented tile index.
  """
  batch_size = tf.shape(boxes)[0]
  new_slice = tf.slice(boxes, [0, inner_idx * NMS_TILE_SIZE, 0],
                       [batch_size, NMS_TILE_SIZE, 4])
  # NOTE(review): presumably indexed [batch, new_slice box, box_slice box];
  # box_ops.bbox_overlap is defined elsewhere -- confirm.
  iou = box_ops.bbox_overlap(new_slice, box_slice)
  # A box survives only if its overlap with every box of the suppressing tile
  # stays below the threshold.
  ret_slice = tf.expand_dims(
      tf.cast(tf.reduce_all(iou < iou_threshold, [1]), box_slice.dtype),
      2) * box_slice
  return boxes, ret_slice, iou_threshold, inner_idx + 1


def _suppression_loop_body(boxes, iou_threshold, output_size, idx):
  """Process boxes in the range [idx*NMS_TILE_SIZE, (idx+1)*NMS_TILE_SIZE).

  Args:
    boxes: a tensor with a shape of [batch_size, anchors, 4].
    iou_threshold: a float representing the threshold for deciding whether
      boxes overlap too much with respect to IOU.
    output_size: an int32 tensor of size [batch_size]. Representing the number
      of selected boxes for each batch.
    idx: an integer scalar representing induction variable.

  Returns:
    boxes: updated boxes.
    iou_threshold: pass down iou_threshold to the next iteration.
    output_size: the updated output_size.
    idx: the updated induction variable.
  """
  boxes_shape = tf.shape(boxes)
  num_tiles = boxes_shape[1] // NMS_TILE_SIZE
  batch_size = boxes_shape[0]

  # Iterates over tiles that can possibly suppress the current tile.
  box_slice = tf.slice(boxes, [0, idx * NMS_TILE_SIZE, 0],
                       [batch_size, NMS_TILE_SIZE, 4])
  _, box_slice, _, _ = tf.while_loop(
      lambda _boxes, _box_slice, _threshold, inner_idx: inner_idx < idx,
      _cross_suppression, [boxes, box_slice, iou_threshold,
                           tf.constant(0)])

  # Iterates over the current tile to compute self-suppression.
  iou = box_ops.bbox_overlap(box_slice, box_slice)
  # Strict upper triangle: only a box at a smaller position within the tile
  # may suppress a box at a larger position.
  mask = tf.expand_dims(
      tf.reshape(tf.range(NMS_TILE_SIZE), [1, -1]) > tf.reshape(
          tf.range(NMS_TILE_SIZE), [-1, 1]), 0)
  # Use a 0/1 indicator matrix instead of the raw IoU values.
  # `_self_suppression` separates zero from non-zero entries with a fixed 0.5
  # cut-off, which misclassifies raw IoU values in [iou_threshold, 0.5] when
  # iou_threshold <= 0.5 and then leaves boxes wrongly suppressed.
  iou = tf.cast(tf.logical_and(mask, iou >= iou_threshold), iou.dtype)
  suppressed_iou, _, _ = tf.while_loop(
      lambda _iou, loop_condition, _iou_sum: loop_condition,
      _self_suppression,
      [iou, tf.constant(True),
       tf.reduce_sum(iou, [1, 2])])
  # A box is suppressed if any remaining entry in its column is non-zero.
  suppressed_box = tf.reduce_sum(suppressed_iou, 1) > 0
  box_slice *= tf.expand_dims(1.0 - tf.cast(suppressed_box, box_slice.dtype), 2)

  # Uses box_slice to update the input boxes.
  mask = tf.reshape(
      tf.cast(tf.equal(tf.range(num_tiles), idx), boxes.dtype), [1, -1, 1, 1])
  boxes = tf.tile(tf.expand_dims(
      box_slice, [1]), [1, num_tiles, 1, 1]) * mask + tf.reshape(
          boxes, [batch_size, num_tiles, NMS_TILE_SIZE, 4]) * (1 - mask)
  boxes = tf.reshape(boxes, boxes_shape)

  # Updates output_size.
  output_size += tf.reduce_sum(
      tf.cast(tf.reduce_any(box_slice > 0, [2]), tf.int32), [1])
  return boxes, iou_threshold, output_size, idx + 1


def sorted_non_max_suppression_padded(scores, boxes, max_output_size,
                                      iou_threshold):
  """A wrapper that handles non-maximum suppression.

  Assumption:
    * The boxes are sorted by scores unless the box is a dot (all coordinates
      are zero).
    * Boxes with higher scores can be used to suppress boxes with lower scores.

  The overall design of the algorithm is to handle boxes tile-by-tile:

  boxes = boxes.pad_to_multiply_of(tile_size)
  num_tiles = len(boxes) // tile_size
  output_boxes = []
  for i in range(num_tiles):
    box_tile = boxes[i*tile_size : (i+1)*tile_size]
    for j in range(i):
      suppressing_tile = boxes[j*tile_size : (j+1)*tile_size]
      iou = bbox_overlap(box_tile, suppressing_tile)
      # if the box is suppressed in iou, clear it to a dot
      box_tile *= _update_boxes(iou)
    # Iteratively handle the diagonal tile.
    iou = bbox_overlap(box_tile, box_tile)
    iou_changed = True
    while iou_changed:
      # boxes that are not suppressed by anything else
      suppressing_boxes = _get_suppressing_boxes(iou)
      # boxes that are suppressed by suppressing_boxes
      suppressed_boxes = _get_suppressed_boxes(iou, suppressing_boxes)
      # clear iou to 0 for boxes that are suppressed, as they cannot be used
      # to suppress other boxes any more
      new_iou = _clear_iou(iou, suppressed_boxes)
      iou_changed = (new_iou != iou)
      iou = new_iou
    # remaining boxes that can still suppress others, are selected boxes.
    output_boxes.append(_get_suppressing_boxes(iou))
    if len(output_boxes) >= max_output_size:
      break

  Args:
    scores: a tensor with a shape of [batch_size, anchors].
    boxes: a tensor with a shape of [batch_size, anchors, 4].
    max_output_size: a scalar integer `Tensor` representing the maximum number
      of boxes to be selected by non max suppression.
    iou_threshold: a float representing the threshold for deciding whether
      boxes overlap too much with respect to IOU.

  Returns:
    nms_scores: a float32 tensor with a shape of
      [batch_size, max_output_size]. Entries beyond the number of selected
      boxes of a batch element are zero.
    nms_proposals: a float32 tensor with a shape of
      [batch_size, max_output_size, 4]. Entries beyond the number of selected
      boxes of a batch element are zero.
  """
  batch_size = tf.shape(boxes)[0]
  num_boxes = tf.shape(boxes)[1]
  # Pad the anchor axis up to the next multiple of NMS_TILE_SIZE. Padded boxes
  # are all-zero "dots" and padded scores are -1.
  pad = tf.cast(
      tf.math.ceil(tf.cast(num_boxes, tf.float32) / NMS_TILE_SIZE),
      tf.int32) * NMS_TILE_SIZE - num_boxes
  boxes = tf.pad(tf.cast(boxes, tf.float32), [[0, 0], [0, pad], [0, 0]])
  scores = tf.pad(
      tf.cast(scores, tf.float32), [[0, 0], [0, pad]], constant_values=-1)
  num_boxes += pad

  def _loop_cond(unused_boxes, unused_threshold, output_size, idx):
    # Continue while some batch element still has fewer than max_output_size
    # selected boxes and there are tiles left to process.
    return tf.logical_and(
        tf.reduce_min(output_size) < max_output_size,
        idx < num_boxes // NMS_TILE_SIZE)

  selected_boxes, _, output_size, _ = tf.while_loop(
      _loop_cond, _suppression_loop_body, [
          boxes, iou_threshold,
          tf.zeros([batch_size], tf.int32),
          tf.constant(0)
      ])
  # Surviving boxes (any coordinate > 0) are weighted by a descending ramp so
  # that top_k returns the earliest survivors first; `num_boxes - value` turns
  # the ramp value back into a position along the anchor axis.
  idx = num_boxes - tf.cast(
      tf.nn.top_k(
          tf.cast(tf.reduce_any(selected_boxes > 0, [2]), tf.int32) *
          tf.expand_dims(tf.range(num_boxes, 0, -1), 0), max_output_size)[0],
      tf.int32)
  # Non-surviving entries have value 0 and would map to `num_boxes`, which is
  # out of range; clamp them to the last valid position.
  idx = tf.minimum(idx, num_boxes - 1)
  # Offset by batch so the indices address the flattened
  # [batch_size * num_boxes] layout.
  idx = tf.reshape(
      idx + tf.reshape(tf.range(batch_size) * num_boxes, [-1, 1]), [-1])
  boxes = tf.reshape(
      tf.gather(tf.reshape(boxes, [-1, 4]), idx),
      [batch_size, max_output_size, 4])
  # Zero out the slots beyond the number of boxes selected per batch element.
  boxes = boxes * tf.cast(
      tf.reshape(tf.range(max_output_size), [1, -1, 1]) < tf.reshape(
          output_size, [-1, 1, 1]), boxes.dtype)
  scores = tf.reshape(
      tf.gather(tf.reshape(scores, [-1, 1]), idx),
      [batch_size, max_output_size])
  scores = scores * tf.cast(
      tf.reshape(tf.range(max_output_size), [1, -1]) < tf.reshape(
          output_size, [-1, 1]), scores.dtype)
  return scores, boxes