File size: 3,411 Bytes
5672777
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Contains definitions of box sampler."""

# Import libraries
import tensorflow as tf, tf_keras

from official.vision.ops import sampling_ops


@tf_keras.utils.register_keras_serializable(package='Vision')
class BoxSampler(tf_keras.layers.Layer):
  """Creates a BoxSampler to sample positive and negative boxes.

  For every image in the batch, candidate boxes (positive or negative matches
  that are not ignored) are subsampled with
  `sampling_ops.BalancedPositiveNegativeSampler`, and the indices of the
  sampled boxes are returned.
  """

  def __init__(self,
               num_samples: int = 512,
               foreground_fraction: float = 0.25,
               **kwargs):
    """Initializes a box sampler.

    Args:
      num_samples: An `int` of the number of sampled boxes per image.
      foreground_fraction: A `float` in [0, 1], what percentage of boxes should
        be sampled from the positive examples.
      **kwargs: Additional keyword arguments passed to Layer.
    """
    self._config_dict = {
        'num_samples': num_samples,
        'foreground_fraction': foreground_fraction,
    }
    super().__init__(**kwargs)

  def call(self, positive_matches: tf.Tensor, negative_matches: tf.Tensor,
           ignored_matches: tf.Tensor) -> tf.Tensor:
    """Samples and selects positive and negative instances.

    Args:
      positive_matches: A `bool` tensor of shape of [batch, N] where N is the
        number of instances. For each element, `True` means the instance
        corresponds to a positive example.
      negative_matches: A `bool` tensor of shape of [batch, N] where N is the
        number of instances. For each element, `True` means the instance
        corresponds to a negative example.
      ignored_matches: A `bool` tensor of shape of [batch, N] where N is the
        number of instances. For each element, `True` means the instance should
        be ignored.

    Returns:
      A `tf.Tensor` of shape of [batch_size, K], storing the indices of the
        sampled examples, where K is `num_samples`. Within each row the
        indices of sampled instances come first (in ascending index order),
        followed by indices of unsampled instances if fewer than K were
        sampled. `N` must be at least `num_samples`, since `tf.nn.top_k`
        requires `k` not to exceed the last dimension.

    Raises:
      ValueError: If the batch dimension of the inputs is not statically known.
    """
    num_samples = self._config_dict['num_samples']

    # A box is a sampling candidate if it matched as either positive or
    # negative and is not marked as ignored.
    sample_candidates = tf.logical_and(
        tf.logical_or(positive_matches, negative_matches),
        tf.logical_not(ignored_matches))

    sampler = sampling_ops.BalancedPositiveNegativeSampler(
        positive_fraction=self._config_dict['foreground_fraction'],
        is_static=True)

    # The per-image sampling below is unrolled with a Python loop, so the
    # batch dimension has to be known at trace time.
    batch_size = sample_candidates.shape[0]
    if batch_size is None:
      raise ValueError(
          'BoxSampler requires a statically known batch size, but got inputs '
          'of shape {}.'.format(sample_candidates.shape))

    sampled_indicators = tf.stack([
        sampler.subsample(sample_candidates[i], num_samples,
                          positive_matches[i])
        for i in range(batch_size)
    ])

    # Sampled entries become 1 and the rest 0, so top_k ranks the sampled
    # boxes first; top_k breaks ties by lower index, giving ascending indices
    # within each group.
    _, selected_indices = tf.nn.top_k(
        tf.cast(sampled_indicators, dtype=tf.int32),
        k=num_samples,
        sorted=True)

    return selected_indices

  def get_config(self):
    # Return a copy so callers cannot mutate this layer's configuration
    # through the returned dict.
    return dict(self._config_dict)

  @classmethod
  def from_config(cls, config):
    return cls(**config)