File size: 10,006 Bytes
5672777
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Contains definitions of ROI sampler."""
from typing import Optional, Tuple, Union
# Import libraries
import tensorflow as tf, tf_keras

from official.vision.modeling.layers import box_sampler
from official.vision.ops import box_matcher
from official.vision.ops import iou_similarity
from official.vision.ops import target_gather

# `ROISampler.call` yields four tensors, or five when outer boxes are supplied.
_FourTensors = Tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor]
_FiveTensors = Tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor]
ROISamplerReturnType = Union[_FourTensors, _FiveTensors]


@tf_keras.utils.register_keras_serializable(package='Vision')
class ROISampler(tf_keras.layers.Layer):
  """Samples ROIs and assigns targets to the sampled ROIs.

  Proposals are matched against groundtruth boxes by IoU, labelled as
  positive, negative, ignored or invalid, and, unless `skip_subsampling` is
  set, subsampled with a `box_sampler.BoxSampler`.
  """

  def __init__(self,
               mix_gt_boxes: bool = True,
               num_sampled_rois: int = 512,
               foreground_fraction: float = 0.25,
               foreground_iou_threshold: float = 0.5,
               background_iou_high_threshold: float = 0.5,
               background_iou_low_threshold: float = 0,
               skip_subsampling: bool = False,
               **kwargs):
    """Initializes a ROI sampler.

    Args:
      mix_gt_boxes: A `bool` of whether to mix the groundtruth boxes with
        proposed ROIs.
      num_sampled_rois: An `int` of the number of sampled ROIs per image.
      foreground_fraction: A `float` in [0, 1], what percentage of proposed ROIs
        should be sampled from the foreground boxes.
      foreground_iou_threshold: A `float` that represents the IoU threshold for
        a box to be considered as positive (if >= `foreground_iou_threshold`).
      background_iou_high_threshold: A `float` that represents the IoU threshold
        for a box to be considered as negative (if overlap in
        [`background_iou_low_threshold`, `background_iou_high_threshold`]).
      background_iou_low_threshold: A `float` that represents the IoU threshold
        for a box to be considered as negative (if overlap in
        [`background_iou_low_threshold`, `background_iou_high_threshold`]).
      skip_subsampling: A `bool` that determines if we want to skip the sampling
        procedure that balances the fg/bg classes. Used for upper frcnn layers
        in cascade RCNN.
      **kwargs: Additional keyword arguments passed to Layer.
    """
    # Set up the base layer before assigning any attribute or sub-layer.
    super().__init__(**kwargs)

    # Constructor arguments, kept for `get_config` and read back in `call`.
    self._config_dict = {
        'mix_gt_boxes': mix_gt_boxes,
        'num_sampled_rois': num_sampled_rois,
        'foreground_fraction': foreground_fraction,
        'foreground_iou_threshold': foreground_iou_threshold,
        'background_iou_high_threshold': background_iou_high_threshold,
        'background_iou_low_threshold': background_iou_low_threshold,
        'skip_subsampling': skip_subsampling,
    }

    self._sim_calc = iou_similarity.IouSimilarity()
    # `call` interprets the indicators as: -3 invalid, -1 negative
    # (background), -2 ignored, >= 0 positive. Presumably one indicator is
    # assigned per IoU interval delimited by the thresholds, in order; see
    # `box_matcher.BoxMatcher` to confirm.
    self._box_matcher = box_matcher.BoxMatcher(
        thresholds=[
            background_iou_low_threshold,
            background_iou_high_threshold,
            foreground_iou_threshold,
        ],
        indicators=[-3, -1, -2, 1])
    self._target_gather = target_gather.TargetGather()

    self._sampler = box_sampler.BoxSampler(
        num_sampled_rois, foreground_fraction)

  def call(
      self,
      boxes: tf.Tensor,
      gt_boxes: tf.Tensor,
      gt_classes: tf.Tensor,
      gt_outer_boxes: Optional[tf.Tensor] = None) -> ROISamplerReturnType:
    """Assigns the proposals with groundtruth classes and performs subsampling.

    Given `boxes`, `gt_boxes`, and `gt_classes`, the function uses the
    following algorithm to generate the final `num_sampled_rois` RoIs.
      1. Calculates the IoU between each proposal box and each gt_boxes.
      2. Assigns each proposed box with a groundtruth class and box by choosing
         the largest IoU overlap.
      3. Samples `num_sampled_rois` boxes from all proposed boxes, and
         returns box_targets, class_targets, and RoIs.

    Boxes matched as background (negative or invalid matches) get all-zero box
    targets, class target 0 and groundtruth index -1.

    If `skip_subsampling` is True, step 3 is skipped and the outputs cover every
    box in `boxes` (including the appended groundtruth boxes when `mix_gt_boxes`
    is True) instead of K sampled ones.

    Args:
      boxes: A `tf.Tensor` of shape of [batch_size, N, 4]. N is the number of
        proposals before groundtruth assignment. The last dimension is the
        box coordinates w.r.t. the scaled images in [ymin, xmin, ymax, xmax]
        format.
      gt_boxes: A `tf.Tensor` of shape of [batch_size, MAX_NUM_INSTANCES, 4].
        The coordinates of gt_boxes are in the pixel coordinates of the scaled
        image. This tensor might have padding of values -1 indicating the
        invalid box coordinates.
      gt_classes: A `tf.Tensor` with a shape of [batch_size, MAX_NUM_INSTANCES].
        This tensor might have paddings with values of -1 indicating the invalid
        classes.
      gt_outer_boxes: A `tf.Tensor` of shape of [batch_size, MAX_NUM_INSTANCES,
        4]. The coordinates of gt_outer_boxes are in the pixel coordinates of
        the scaled image. This tensor might have padding of values -1 indicating
        the invalid box coordinates. Ignored if not provided.

    Returns:
      sampled_rois: A `tf.Tensor` of shape of [batch_size, K, 4], representing
        the coordinates of the sampled RoIs, where K is the number of the
        sampled RoIs, i.e. K = num_sampled_rois.
      sampled_gt_boxes: A `tf.Tensor` of shape of [batch_size, K, 4], storing
        the box coordinates of the matched groundtruth boxes of the sampled
        RoIs.
      sampled_gt_outer_boxes: A `tf.Tensor` of shape of [batch_size, K, 4],
        storing the box coordinates of the matched groundtruth outer boxes of
        the sampled RoIs. This field is missing if gt_outer_boxes is None.
      sampled_gt_classes: A `tf.Tensor` of shape of [batch_size, K], storing the
        classes of the matched groundtruth boxes of the sampled RoIs.
      sampled_gt_indices: A `tf.Tensor` of shape of [batch_size, K], storing the
        indices of the sampled groundtruth boxes in the original `gt_boxes`
        tensor, i.e.,
        gt_boxes[sampled_gt_indices[:, i]] = sampled_gt_boxes[:, i].
    """
    gt_boxes = tf.cast(gt_boxes, dtype=boxes.dtype)
    if self._config_dict['mix_gt_boxes']:
      # Append the groundtruth boxes to the proposals as extra candidates.
      boxes = tf.concat([boxes, gt_boxes], axis=1)

    # A box whose largest coordinate is negative is treated as padding.
    boxes_invalid_mask = tf.less(
        tf.reduce_max(boxes, axis=-1, keepdims=True), 0.0)
    gt_invalid_mask = tf.less(
        tf.reduce_max(gt_boxes, axis=-1, keepdims=True), 0.0)
    similarity_matrix = self._sim_calc(boxes, gt_boxes, boxes_invalid_mask,
                                       gt_invalid_mask)
    matched_gt_indices, match_indicators = self._box_matcher(similarity_matrix)
    positive_matches = tf.greater_equal(match_indicators, 0)
    negative_matches = tf.equal(match_indicators, -1)
    ignored_matches = tf.equal(match_indicators, -2)
    invalid_matches = tf.equal(match_indicators, -3)

    # Negative and invalid matches are both treated as background. The mask
    # gets a trailing axis of size 1; the tiled variant has a trailing axis of
    # size 4 for use with the box tensors.
    background_mask = tf.expand_dims(
        tf.logical_or(negative_matches, invalid_matches), -1)
    box_background_mask = tf.tile(background_mask, [1, 1, 4])

    gt_classes = tf.expand_dims(gt_classes, axis=-1)
    matched_gt_classes = self._target_gather(gt_classes, matched_gt_indices,
                                             background_mask)
    # Background boxes get class 0.
    matched_gt_classes = tf.where(background_mask,
                                  tf.zeros_like(matched_gt_classes),
                                  matched_gt_classes)
    matched_gt_boxes = self._target_gather(gt_boxes, matched_gt_indices,
                                           box_background_mask)
    # Background boxes get an all-zero box target.
    matched_gt_boxes = tf.where(background_mask,
                                tf.zeros_like(matched_gt_boxes),
                                matched_gt_boxes)

    # Stays None when no outer boxes were supplied.
    matched_gt_outer_boxes = None
    if gt_outer_boxes is not None:
      # NOTE(review): unlike `gt_boxes`, `gt_outer_boxes` is not cast to
      # `boxes.dtype` -- confirm this is intentional.
      matched_gt_outer_boxes = self._target_gather(
          gt_outer_boxes, matched_gt_indices, box_background_mask)
      matched_gt_outer_boxes = tf.where(background_mask,
                                        tf.zeros_like(matched_gt_outer_boxes),
                                        matched_gt_outer_boxes)

    # Background boxes have no groundtruth match; mark their index as -1.
    matched_gt_indices = tf.where(
        tf.squeeze(background_mask, -1), -tf.ones_like(matched_gt_indices),
        matched_gt_indices)

    if self._config_dict['skip_subsampling']:
      # Return targets for every box without fg/bg balancing.
      matched_gt_classes = tf.squeeze(matched_gt_classes, axis=-1)
      if gt_outer_boxes is None:
        return (boxes, matched_gt_boxes, matched_gt_classes, matched_gt_indices)
      return (boxes, matched_gt_boxes, matched_gt_outer_boxes,
              matched_gt_classes, matched_gt_indices)

    sampled_indices = self._sampler(
        positive_matches, negative_matches, ignored_matches)

    sampled_rois = self._target_gather(boxes, sampled_indices)
    sampled_gt_boxes = self._target_gather(matched_gt_boxes, sampled_indices)
    sampled_gt_classes = tf.squeeze(self._target_gather(
        matched_gt_classes, sampled_indices), axis=-1)
    # The indices get a trailing axis for the gather, dropped again afterwards.
    sampled_gt_indices = tf.squeeze(self._target_gather(
        tf.expand_dims(matched_gt_indices, -1), sampled_indices), axis=-1)
    if gt_outer_boxes is None:
      return (sampled_rois, sampled_gt_boxes, sampled_gt_classes,
              sampled_gt_indices)
    sampled_gt_outer_boxes = self._target_gather(matched_gt_outer_boxes,
                                                 sampled_indices)
    return (sampled_rois, sampled_gt_boxes, sampled_gt_outer_boxes,
            sampled_gt_classes, sampled_gt_indices)

  def get_config(self):
    """Returns the constructor arguments of this layer as a new dict.

    Only the named arguments of `__init__` are included; values passed through
    `**kwargs` are not recorded. A shallow copy is returned so that callers
    mutating the result cannot alter the configuration stored on the layer.
    """
    return dict(self._config_dict)

  @classmethod
  def from_config(cls, config):
    """Builds a `ROISampler` from a config as produced by `get_config`."""
    return cls(**config)