File size: 9,822 Bytes
5672777
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Contains the definitions of Feature Pyramid Networks (FPN)."""
from typing import Any, Mapping, Optional

# Import libraries
from absl import logging
import tensorflow as tf, tf_keras

from official.modeling import hyperparams
from official.modeling import tf_utils
from official.vision.modeling.decoders import factory
from official.vision.ops import spatial_transform_ops


@tf_keras.utils.register_keras_serializable(package='Vision')
class FPN(tf_keras.Model):
  """Creates a Feature Pyramid Network (FPN).

  This implements the paper:
  Tsung-Yi Lin, Piotr Dollar, Ross Girshick, Kaiming He, Bharath Hariharan, and
  Serge Belongie.
  Feature Pyramid Networks for Object Detection.
  (https://arxiv.org/pdf/1612.03144)

  The model is assembled as a functional Keras graph inside `__init__`:
  1x1 lateral convolutions on the backbone levels, a top-down path that
  upsamples by 2 and fuses with the lateral features, a 3x3 convolution on
  each fused level, stride-2 3x3 convolutions for any levels above the
  highest backbone level, and finally batch normalization on every output
  level.
  """

  def __init__(
      self,
      input_specs: Mapping[str, tf.TensorShape],
      min_level: int = 3,
      max_level: int = 7,
      num_filters: int = 256,
      fusion_type: str = 'sum',
      use_separable_conv: bool = False,
      use_keras_layer: bool = False,
      activation: str = 'relu',
      use_sync_bn: bool = False,
      norm_momentum: float = 0.99,
      norm_epsilon: float = 0.001,
      kernel_initializer: str = 'VarianceScaling',
      kernel_regularizer: Optional[tf_keras.regularizers.Regularizer] = None,
      bias_regularizer: Optional[tf_keras.regularizers.Regularizer] = None,
      **kwargs):
    """Initializes a Feature Pyramid Network (FPN).

    Args:
      input_specs: A `dict` of input specifications. A dictionary consists of
        {level: TensorShape} from a backbone. Level keys are strings that
        parse as integers (e.g. '3').
      min_level: An `int` of minimum level in FPN output feature maps.
      max_level: An `int` of maximum level in FPN output feature maps.
      num_filters: An `int` number of filters in FPN layers.
      fusion_type: A `str` of `sum` or `concat`. Whether performing sum or
        concat for feature fusion.
      use_separable_conv: A `bool`.  If True use separable convolution for
        convolution in FPN layers.
      use_keras_layer: A `bool`. If True use keras layers as many as possible.
      activation: A `str` name of the activation function.
      use_sync_bn: A `bool`. If True, use synchronized batch normalization.
      norm_momentum: A `float` of normalization momentum for the moving average.
      norm_epsilon: A `float` added to variance to avoid dividing by zero.
      kernel_initializer: A `str` name of kernel_initializer for convolutional
        layers.
      kernel_regularizer: A `tf_keras.regularizers.Regularizer` object for
        Conv2D. Default is None.
      bias_regularizer: A `tf_keras.regularizers.Regularizer` object for Conv2D.
        Default is None.
      **kwargs: Additional keyword arguments to be passed.

    Raises:
      ValueError: If `input_specs` is empty, if the lowest backbone level is
        above `min_level`, or if `fusion_type` is neither `sum` nor `concat`
        when the top-down path has at least one level to fuse.
    """
    # Kept verbatim so that get_config()/from_config() can round-trip the
    # constructor arguments.
    self._config_dict = {
        'input_specs': input_specs,
        'min_level': min_level,
        'max_level': max_level,
        'num_filters': num_filters,
        'fusion_type': fusion_type,
        'use_separable_conv': use_separable_conv,
        'use_keras_layer': use_keras_layer,
        'activation': activation,
        'use_sync_bn': use_sync_bn,
        'norm_momentum': norm_momentum,
        'norm_epsilon': norm_epsilon,
        'kernel_initializer': kernel_initializer,
        'kernel_regularizer': kernel_regularizer,
        'bias_regularizer': bias_regularizer,
    }
    conv2d = (
        tf_keras.layers.SeparableConv2D
        if use_separable_conv
        else tf_keras.layers.Conv2D
    )
    norm = tf_keras.layers.BatchNormalization
    activation_fn = tf_utils.get_activation(activation, use_keras_layer=True)

    # Channel axis used by the batch norm layers.
    bn_axis = (
        -1 if tf_keras.backend.image_data_format() == 'channels_last' else 1
    )

    # Get input feature pyramid from backbone.
    logging.info('FPN input_specs: %s', input_specs)
    inputs = self._build_input_pyramid(input_specs, min_level)
    # Level keys are strings; compare them as integers so that a two-digit
    # level (e.g. '10') is not ordered lexicographically before '3'.
    backbone_max_level = min(max(int(level) for level in inputs), max_level)

    # Build lateral connections.
    feats_lateral = {}
    for level in range(min_level, backbone_max_level + 1):
      feats_lateral[str(level)] = conv2d(
          filters=num_filters,
          kernel_size=1,
          padding='same',
          kernel_initializer=kernel_initializer,
          kernel_regularizer=kernel_regularizer,
          bias_regularizer=bias_regularizer,
          name=f'lateral_{level}')(
              inputs[str(level)])

    # Build top-down path, starting from the coarsest backbone level and
    # walking down to min_level.
    feats = {str(backbone_max_level): feats_lateral[str(backbone_max_level)]}
    for level in range(backbone_max_level - 1, min_level - 1, -1):
      feat_a = spatial_transform_ops.nearest_upsampling(
          feats[str(level + 1)], 2, use_keras_layer=use_keras_layer)
      feat_b = feats_lateral[str(level)]

      if fusion_type == 'sum':
        if use_keras_layer:
          feats[str(level)] = tf_keras.layers.Add()([feat_a, feat_b])
        else:
          feats[str(level)] = feat_a + feat_b
      elif fusion_type == 'concat':
        if use_keras_layer:
          feats[str(level)] = tf_keras.layers.Concatenate(axis=-1)(
              [feat_a, feat_b])
        else:
          feats[str(level)] = tf.concat([feat_a, feat_b], axis=-1)
      else:
        raise ValueError('Fusion type {} not supported.'.format(fusion_type))

    # TODO(fyangf): experiment with removing bias in conv2d.
    # Build post-hoc 3x3 convolution kernel.
    for level in range(min_level, backbone_max_level + 1):
      feats[str(level)] = conv2d(
          filters=num_filters,
          strides=1,
          kernel_size=3,
          padding='same',
          kernel_initializer=kernel_initializer,
          kernel_regularizer=kernel_regularizer,
          bias_regularizer=bias_regularizer,
          name=f'post_hoc_{level}')(
              feats[str(level)])

    # TODO(fyangf): experiment with removing bias in conv2d.
    # Build coarser FPN levels introduced for RetinaNet.
    for level in range(backbone_max_level + 1, max_level + 1):
      feats_in = feats[str(level - 1)]
      # The first extra level consumes the previous level directly; later
      # extra levels apply the activation to their input first.
      if level > backbone_max_level + 1:
        feats_in = activation_fn(feats_in)
      feats[str(level)] = conv2d(
          filters=num_filters,
          strides=2,
          kernel_size=3,
          padding='same',
          kernel_initializer=kernel_initializer,
          kernel_regularizer=kernel_regularizer,
          bias_regularizer=bias_regularizer,
          name=f'coarser_{level}')(
              feats_in)

    # Apply batch norm layers.
    for level in range(min_level, max_level + 1):
      feats[str(level)] = norm(
          axis=bn_axis,
          momentum=norm_momentum,
          epsilon=norm_epsilon,
          synchronized=use_sync_bn,
          name=f'norm_{level}')(
              feats[str(level)])

    self._output_specs = {
        str(level): feats[str(level)].get_shape()
        for level in range(min_level, max_level + 1)
    }

    super().__init__(inputs=inputs, outputs=feats, **kwargs)

  def _build_input_pyramid(self, input_specs: Mapping[str, tf.TensorShape],
                           min_level: int):
    """Creates one Keras input per backbone level.

    Args:
      input_specs: A `dict` of {level: TensorShape}. The leading dimension of
        each shape is dropped when creating the corresponding input.
      min_level: An `int` of minimum level in FPN output feature maps.

    Returns:
      A `dict` of {level: `tf_keras.Input`} with the same keys as
      `input_specs`.

    Raises:
      ValueError: If `input_specs` is empty or its lowest level is greater
        than `min_level`.
    """
    assert isinstance(input_specs, dict)
    if not input_specs:
      raise ValueError('input_specs should contain at least one level.')
    # Compare levels numerically; comparing the string keys would order '10'
    # before '3'.
    if min(int(level) for level in input_specs) > min_level:
      raise ValueError(
          'Backbone min level should be less or equal to FPN min level')

    return {
        level: tf_keras.Input(shape=spec[1:])
        for level, spec in input_specs.items()
    }

  def get_config(self) -> Mapping[str, Any]:
    """Returns a copy of the constructor arguments used to build this model."""
    # Shallow copy so that callers cannot mutate the stored config.
    return dict(self._config_dict)

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Rebuilds the model from `get_config()` output.

    Args:
      config: A mapping of constructor keyword arguments.
      custom_objects: Unused; accepted for signature compatibility.

    Returns:
      A new instance of `cls` constructed as `cls(**config)`.
    """
    return cls(**config)

  @property
  def output_specs(self) -> Mapping[str, tf.TensorShape]:
    """A dict of {level: TensorShape} pairs for the model output."""
    return self._output_specs


@factory.register_decoder_builder('fpn')
def build_fpn_decoder(
    input_specs: Mapping[str, tf.TensorShape],
    model_config: hyperparams.Config,
    l2_regularizer: Optional[tf_keras.regularizers.Regularizer] = None
) -> tf_keras.Model:
  """Constructs an `FPN` decoder from a model config.

  Args:
    input_specs: A `dict` of {level: TensorShape} describing the backbone
      outputs that feed the decoder.
    model_config: The model config. Its `decoder` one-of supplies the FPN
      options, `norm_activation` supplies the normalization/activation
      options, and `min_level`/`max_level` bound the output levels.
    l2_regularizer: A `tf_keras.regularizers.Regularizer` instance used as the
      kernel regularizer. Default to None.

  Returns:
    A `tf_keras.Model` instance of the FPN decoder.

  Raises:
    ValueError: If the model_config.decoder.type is not `fpn`.
  """
  selected_type = model_config.decoder.type
  fpn_cfg = model_config.decoder.get()
  if selected_type != 'fpn':
    message = f'Inconsistent decoder type {selected_type}. Need to be `fpn`.'
    raise ValueError(message)

  norm_act_cfg = model_config.norm_activation
  # Collected in the same order the values are read from the config.
  fpn_kwargs = {
      'input_specs': input_specs,
      'min_level': model_config.min_level,
      'max_level': model_config.max_level,
      'num_filters': fpn_cfg.num_filters,
      'fusion_type': fpn_cfg.fusion_type,
      'use_separable_conv': fpn_cfg.use_separable_conv,
      'use_keras_layer': fpn_cfg.use_keras_layer,
      'activation': norm_act_cfg.activation,
      'use_sync_bn': norm_act_cfg.use_sync_bn,
      'norm_momentum': norm_act_cfg.norm_momentum,
      'norm_epsilon': norm_act_cfg.norm_epsilon,
      'kernel_regularizer': l2_regularizer,
  }
  return FPN(**fpn_kwargs)