# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Contains definitions of ROI sampler."""

from typing import Optional, Tuple, Union

# Import libraries
import tensorflow as tf
import tf_keras

from official.vision.modeling.layers import box_sampler
from official.vision.ops import box_matcher
from official.vision.ops import iou_similarity
from official.vision.ops import target_gather

# The return type can be a tuple of 4 or 5 tf.Tensor.
ROISamplerReturnType = Union[
    Tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor],
    Tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor]]


@tf_keras.utils.register_keras_serializable(package='Vision')
class ROISampler(tf_keras.layers.Layer):
  """Samples ROIs and assigns targets to the sampled ROIs."""

  def __init__(self,
               mix_gt_boxes: bool = True,
               num_sampled_rois: int = 512,
               foreground_fraction: float = 0.25,
               foreground_iou_threshold: float = 0.5,
               background_iou_high_threshold: float = 0.5,
               background_iou_low_threshold: float = 0,
               skip_subsampling: bool = False,
               **kwargs):
    """Initializes a ROI sampler.

    Args:
      mix_gt_boxes: A `bool` of whether to mix the groundtruth boxes with
        proposed ROIs.
      num_sampled_rois: An `int` of the number of sampled ROIs per image.
      foreground_fraction: A `float` in [0, 1], what percentage of proposed
        ROIs should be sampled from the foreground boxes.
      foreground_iou_threshold: A `float` that represents the IoU threshold for
        a box to be considered as positive
        (if >= `foreground_iou_threshold`).
      background_iou_high_threshold: A `float` that represents the IoU
        threshold for a box to be considered as negative (if overlap in
        [`background_iou_low_threshold`, `background_iou_high_threshold`]).
      background_iou_low_threshold: A `float` that represents the IoU
        threshold for a box to be considered as negative (if overlap in
        [`background_iou_low_threshold`, `background_iou_high_threshold`]).
      skip_subsampling: A `bool` that determines if we want to skip the
        sampling procedure that balances the fg/bg classes. Used for upper
        frcnn layers in cascade RCNN.
      **kwargs: Additional keyword arguments passed to Layer.
    """
    self._config_dict = {
        'mix_gt_boxes': mix_gt_boxes,
        'num_sampled_rois': num_sampled_rois,
        'foreground_fraction': foreground_fraction,
        'foreground_iou_threshold': foreground_iou_threshold,
        'background_iou_high_threshold': background_iou_high_threshold,
        'background_iou_low_threshold': background_iou_low_threshold,
        'skip_subsampling': skip_subsampling,
    }

    self._sim_calc = iou_similarity.IouSimilarity()
    # The four indicators label the IoU ranges delimited by the three
    # thresholds. `call` interprets them as: -3 invalid, -1 negative,
    # -2 ignored and 1 positive.
    self._box_matcher = box_matcher.BoxMatcher(
        thresholds=[
            background_iou_low_threshold,
            background_iou_high_threshold,
            foreground_iou_threshold,
        ],
        indicators=[-3, -1, -2, 1])
    self._target_gather = target_gather.TargetGather()

    self._sampler = box_sampler.BoxSampler(
        num_sampled_rois, foreground_fraction)
    super().__init__(**kwargs)

  def call(
      self,
      boxes: tf.Tensor,
      gt_boxes: tf.Tensor,
      gt_classes: tf.Tensor,
      gt_outer_boxes: Optional[tf.Tensor] = None) -> ROISamplerReturnType:
    """Assigns the proposals with groundtruth classes and performs subsampling.

    Given `boxes`, `gt_boxes`, and `gt_classes`, the function uses the
    following algorithm to generate the final `num_sampled_rois` RoIs.
    1. Calculates the IoU between each proposal box and each gt_boxes.
    2. Assigns each proposed box with a groundtruth class and box by choosing
       the largest IoU overlap.
    3. Samples `num_sampled_rois` boxes from all proposed boxes, and returns
       box_targets, class_targets, and RoIs.

    If the layer was built with `skip_subsampling=True`, step 3 is skipped and
    the targets of every box are returned instead.

    Args:
      boxes: A `tf.Tensor` of shape of [batch_size, N, 4]. N is the number of
        proposals before groundtruth assignment. The last dimension is the
        box coordinates w.r.t. the scaled images in [ymin, xmin, ymax, xmax]
        format.
      gt_boxes: A `tf.Tensor` of shape of [batch_size, MAX_NUM_INSTANCES, 4].
        The coordinates of gt_boxes are in the pixel coordinates of the scaled
        image. This tensor might have padding of values -1 indicating the
        invalid box coordinates.
      gt_classes: A `tf.Tensor` with a shape of [batch_size, MAX_NUM_INSTANCES].
        This tensor might have paddings with values of -1 indicating the invalid
        classes.
      gt_outer_boxes: A `tf.Tensor` of shape of [batch_size, MAX_NUM_INSTANCES,
        4]. The coordinates of gt_outer_boxes are in the pixel coordinates of
        the scaled image. This tensor might have padding of values -1
        indicating the invalid box coordinates. Ignored if not provided.

    Returns:
      sampled_rois: A `tf.Tensor` of shape of [batch_size, K, 4], representing
        the coordinates of the sampled RoIs, where K is the number of the
        sampled RoIs, i.e. K = `num_sampled_rois`. When subsampling is
        skipped, this is the full set of boxes (including the mixed-in
        groundtruth boxes if `mix_gt_boxes` is set).
      sampled_gt_boxes: A `tf.Tensor` of shape of [batch_size, K, 4], storing
        the box coordinates of the matched groundtruth boxes of the sampled
        RoIs.
      sampled_gt_outer_boxes: A `tf.Tensor` of shape of [batch_size, K, 4],
        storing the box coordinates of the matched groundtruth outer boxes of
        the sampled RoIs. This field is missing if gt_outer_boxes is None.
      sampled_gt_classes: A `tf.Tensor` of shape of [batch_size, K], storing the
        classes of the matched groundtruth boxes of the sampled RoIs.
      sampled_gt_indices: A `tf.Tensor` of shape of [batch_size, K], storing the
        indices of the sampled groundtruth boxes in the original `gt_boxes`
        tensor, i.e.,
        gt_boxes[sampled_gt_indices[:, i]] = sampled_gt_boxes[:, i].
        Boxes matched as background get index -1.
    """
    gt_boxes = tf.cast(gt_boxes, dtype=boxes.dtype)
    if self._config_dict['mix_gt_boxes']:
      boxes = tf.concat([boxes, gt_boxes], axis=1)

    # A box whose largest coordinate is negative is treated as padding.
    boxes_invalid_mask = tf.less(
        tf.reduce_max(boxes, axis=-1, keepdims=True), 0.0)
    gt_invalid_mask = tf.less(
        tf.reduce_max(gt_boxes, axis=-1, keepdims=True), 0.0)
    similarity_matrix = self._sim_calc(boxes, gt_boxes, boxes_invalid_mask,
                                       gt_invalid_mask)
    matched_gt_indices, match_indicators = self._box_matcher(similarity_matrix)
    positive_matches = tf.greater_equal(match_indicators, 0)
    negative_matches = tf.equal(match_indicators, -1)
    ignored_matches = tf.equal(match_indicators, -2)
    invalid_matches = tf.equal(match_indicators, -3)

    # Negative and invalid matches are both treated as background: their
    # class and box targets are zeroed out and their gt index is set to -1.
    background_mask = tf.expand_dims(
        tf.logical_or(negative_matches, invalid_matches), -1)
    # The same mask tiled over the 4 box coordinates, shared by the gathers
    # of `gt_boxes` and `gt_outer_boxes`.
    background_box_mask = tf.tile(background_mask, [1, 1, 4])

    gt_classes = tf.expand_dims(gt_classes, axis=-1)
    matched_gt_classes = self._target_gather(gt_classes, matched_gt_indices,
                                             background_mask)
    matched_gt_classes = tf.where(background_mask,
                                  tf.zeros_like(matched_gt_classes),
                                  matched_gt_classes)

    matched_gt_boxes = self._target_gather(gt_boxes, matched_gt_indices,
                                           background_box_mask)
    matched_gt_boxes = tf.where(background_mask,
                                tf.zeros_like(matched_gt_boxes),
                                matched_gt_boxes)

    if gt_outer_boxes is not None:
      matched_gt_outer_boxes = self._target_gather(
          gt_outer_boxes, matched_gt_indices, background_box_mask)
      matched_gt_outer_boxes = tf.where(background_mask,
                                        tf.zeros_like(matched_gt_outer_boxes),
                                        matched_gt_outer_boxes)

    # Must happen after all the gathers above, which need the real indices.
    matched_gt_indices = tf.where(
        tf.squeeze(background_mask, -1), -tf.ones_like(matched_gt_indices),
        matched_gt_indices)

    if self._config_dict['skip_subsampling']:
      matched_gt_classes = tf.squeeze(matched_gt_classes, axis=-1)
      if gt_outer_boxes is None:
        return (boxes, matched_gt_boxes, matched_gt_classes,
                matched_gt_indices)
      return (boxes, matched_gt_boxes, matched_gt_outer_boxes,
              matched_gt_classes, matched_gt_indices)

    sampled_indices = self._sampler(
        positive_matches, negative_matches, ignored_matches)

    sampled_rois = self._target_gather(boxes, sampled_indices)
    sampled_gt_boxes = self._target_gather(matched_gt_boxes, sampled_indices)
    sampled_gt_classes = tf.squeeze(
        self._target_gather(matched_gt_classes, sampled_indices), axis=-1)
    sampled_gt_indices = tf.squeeze(
        self._target_gather(
            tf.expand_dims(matched_gt_indices, -1), sampled_indices),
        axis=-1)

    if gt_outer_boxes is None:
      return (sampled_rois, sampled_gt_boxes, sampled_gt_classes,
              sampled_gt_indices)

    sampled_gt_outer_boxes = self._target_gather(matched_gt_outer_boxes,
                                                 sampled_indices)
    return (sampled_rois, sampled_gt_boxes, sampled_gt_outer_boxes,
            sampled_gt_classes, sampled_gt_indices)

  def get_config(self):
    """Returns the constructor arguments needed to re-create this layer.

    Returns:
      A new `dict` holding the arguments given to `__init__` (excluding
      `**kwargs`). A copy is returned because `call` reads its settings from
      the internal dict, so handing out the original would let callers change
      the layer's behavior by editing the returned config.
    """
    return dict(self._config_dict)

  @classmethod
  def from_config(cls, config):
    """Creates a `ROISampler` from a config produced by `get_config`."""
    return cls(**config)