File size: 18,657 Bytes
5672777
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93528c6
 
 
 
 
5672777
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Contains definitions of 3D Residual Networks."""
from typing import Callable, List, Tuple, Optional

# Import libraries
import tensorflow as tf, tf_keras

from official.modeling import hyperparams
from official.modeling import tf_utils
from official.vision.modeling.backbones import factory
from official.vision.modeling.layers import nn_blocks_3d
from official.vision.modeling.layers import nn_layers

layers = tf_keras.layers

# Maps a ResNet depth (`model_id`) to its ordered block-group specs. Each spec
# is a `(block_fn_name, filters, block_repeats)` tuple; every depth uses the
# same four filter sizes and only differs in how often each block is repeated.
RESNET_SPECS = {
    depth: [
        ('bottleneck3d', group_filters, group_repeats)
        for group_filters, group_repeats in zip((64, 128, 256, 512),
                                                repeats_per_group)
    ]
    for depth, repeats_per_group in (
        (50, (3, 4, 6, 3)),
        (101, (3, 4, 23, 3)),
        (152, (3, 8, 36, 3)),
        (200, (3, 24, 36, 3)),
        (270, (4, 29, 53, 4)),
        (300, (4, 36, 54, 4)),
        (350, (4, 36, 72, 4)),
    )
}


@tf_keras.utils.register_keras_serializable(package='Vision')
class ResNet3D(tf_keras.Model):
  """Creates a 3D ResNet family model.

  The model is a functional Keras model made of a convolutional stem, a 3D
  max-pooling layer and one group of 3D blocks per entry of
  `RESNET_SPECS[model_id]`. Its output is a dict that maps the level of each
  block group (as a string, starting at '2') to that group's feature tensor.
  """

  def __init__(
      self,
      model_id: int,
      temporal_strides: List[int],
      temporal_kernel_sizes: List[Tuple[int]],
      use_self_gating: Optional[List[int]] = None,
      input_specs: tf_keras.layers.InputSpec = layers.InputSpec(
          shape=[None, None, None, None, 3]),
      stem_type: str = 'v0',
      stem_conv_temporal_kernel_size: int = 5,
      stem_conv_temporal_stride: int = 2,
      stem_pool_temporal_stride: int = 2,
      init_stochastic_depth_rate: float = 0.0,
      activation: str = 'relu',
      se_ratio: Optional[float] = None,
      use_sync_bn: bool = False,
      norm_momentum: float = 0.99,
      norm_epsilon: float = 0.001,
      kernel_initializer: str = 'VarianceScaling',
      kernel_regularizer: Optional[tf_keras.regularizers.Regularizer] = None,
      bias_regularizer: Optional[tf_keras.regularizers.Regularizer] = None,
      **kwargs):
    """Initializes a 3D ResNet model.

    Args:
      model_id: An `int` of depth of ResNet backbone model. Must be a key of
        `RESNET_SPECS`.
      temporal_strides: A list of integers that specifies the temporal strides
        for all 3d blocks. Must have one entry per block group.
      temporal_kernel_sizes: A list of tuples that specifies the temporal kernel
        sizes for all 3d blocks in different block groups. Must have one tuple
        per block group, each with one entry per block of that group.
      use_self_gating: A list of flags to specify applying self-gating module
        or not in each block group. If None (or empty), self-gating is not
        applied. Otherwise it must have at least one entry per block group.
      input_specs: A `tf_keras.layers.InputSpec` of the input tensor.
      stem_type: A `str` of stem type of ResNet. Default to `v0`. If set to
        `v1`, use ResNet-D type stem (https://arxiv.org/abs/1812.01187).
      stem_conv_temporal_kernel_size: An `int` of temporal kernel size for the
        first conv layer.
      stem_conv_temporal_stride: An `int` of temporal stride for the first conv
        layer.
      stem_pool_temporal_stride: An `int` of temporal stride for the first pool
        layer.
      init_stochastic_depth_rate: A `float` of initial stochastic depth rate.
      activation: A `str` of name of the activation function.
      se_ratio: A `float` or None. Ratio of the Squeeze-and-Excitation layer.
      use_sync_bn: If True, use synchronized batch normalization.
      norm_momentum: A `float` of normalization momentum for the moving average.
      norm_epsilon: A `float` added to variance to avoid dividing by zero.
      kernel_initializer: A str for kernel initializer of convolutional layers.
      kernel_regularizer: A `tf_keras.regularizers.Regularizer` object for the
        Conv3D kernels. Default to None.
      bias_regularizer: A `tf_keras.regularizers.Regularizer` object for the
        Conv3D biases. Default to None.
      **kwargs: Additional keyword arguments to be passed.

    Raises:
      ValueError: If the temporal specs or `use_self_gating` do not match the
        block groups of `RESNET_SPECS[model_id]`, or if `stem_type` is not
        supported.
    """
    self._model_id = model_id
    self._temporal_strides = temporal_strides
    self._temporal_kernel_sizes = temporal_kernel_sizes
    self._input_specs = input_specs
    self._stem_type = stem_type
    self._stem_conv_temporal_kernel_size = stem_conv_temporal_kernel_size
    self._stem_conv_temporal_stride = stem_conv_temporal_stride
    self._stem_pool_temporal_stride = stem_pool_temporal_stride
    self._use_self_gating = use_self_gating
    self._se_ratio = se_ratio
    self._init_stochastic_depth_rate = init_stochastic_depth_rate
    self._use_sync_bn = use_sync_bn
    self._activation = activation
    self._norm_momentum = norm_momentum
    self._norm_epsilon = norm_epsilon
    self._norm = layers.BatchNormalization
    self._kernel_initializer = kernel_initializer
    self._kernel_regularizer = kernel_regularizer
    self._bias_regularizer = bias_regularizer
    # Batch norm normalizes over the channel axis, whose position depends on
    # the configured image data format.
    if tf_keras.backend.image_data_format() == 'channels_last':
      self._bn_axis = -1
    else:
      self._bn_axis = 1

    # Build ResNet3D backbone. The batch dimension is dropped from the spec
    # because `tf_keras.Input` adds it back.
    inputs = tf_keras.Input(shape=input_specs.shape[1:])
    endpoints = self._build_model(inputs)
    self._output_specs = {l: endpoints[l].get_shape() for l in endpoints}

    super().__init__(inputs=inputs, outputs=endpoints, **kwargs)

  def _build_model(self, inputs):
    """Builds model architecture.

    Args:
      inputs: the keras input spec.

    Returns:
      endpoints: A dictionary of backbone endpoint features, keyed by the
        level of each block group as a string ('2', '3', ...).

    Raises:
      ValueError: If the per-group configuration lists do not match the number
        of block groups, or if a block type is not supported.
    """
    # Build stem.
    x = self._build_stem(inputs, stem_type=self._stem_type)

    # A temporal stride of 1 disables temporal pooling, so the temporal extent
    # of the pooling window collapses to 1 as well.
    temporal_kernel_size = 1 if self._stem_pool_temporal_stride == 1 else 3
    x = layers.MaxPool3D(
        pool_size=[temporal_kernel_size, 3, 3],
        strides=[self._stem_pool_temporal_stride, 2, 2],
        padding='same')(x)

    # Build intermediate blocks and endpoints.
    resnet_specs = RESNET_SPECS[self._model_id]
    num_groups = len(resnet_specs)
    if (len(self._temporal_strides) != num_groups or
        len(self._temporal_kernel_sizes) != num_groups):
      raise ValueError(
          'Number of blocks in temporal specs should equal to resnet_specs.')
    # Fail early with a clear message instead of an IndexError in the loop
    # below. Longer lists stay accepted; their extra entries are ignored.
    if self._use_self_gating and len(self._use_self_gating) < num_groups:
      raise ValueError(
          '`use_self_gating` should have one entry per block group: expected '
          'at least {}, got {}.'.format(num_groups,
                                        len(self._use_self_gating)))

    endpoints = {}
    for i, resnet_spec in enumerate(resnet_specs):
      if resnet_spec[0] == 'bottleneck3d':
        block_fn = nn_blocks_3d.BottleneckBlock3D
      else:
        raise ValueError('Block fn `{}` is not supported.'.format(
            resnet_spec[0]))

      use_self_gating = (
          self._use_self_gating[i] if self._use_self_gating else False)
      # Levels start at 2; the first group keeps the spatial resolution while
      # every later group downsamples spatially by 2.
      x = self._block_group(
          inputs=x,
          filters=resnet_spec[1],
          temporal_kernel_sizes=self._temporal_kernel_sizes[i],
          temporal_strides=self._temporal_strides[i],
          spatial_strides=(1 if i == 0 else 2),
          block_fn=block_fn,
          block_repeats=resnet_spec[2],
          stochastic_depth_drop_rate=nn_layers.get_stochastic_depth_rate(
              self._init_stochastic_depth_rate, i + 2, 5),
          use_self_gating=use_self_gating,
          name='block_group_l{}'.format(i + 2))
      endpoints[str(i + 2)] = x

    return endpoints

  def _stem_conv_bn_act(self, inputs, filters, kernel_size, strides):
    """Applies a bias-free Conv3D, batch normalization and the activation."""
    x = layers.Conv3D(
        filters=filters,
        kernel_size=kernel_size,
        strides=strides,
        use_bias=False,
        padding='same',
        kernel_initializer=self._kernel_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer)(
            inputs)
    x = self._norm(
        axis=self._bn_axis,
        momentum=self._norm_momentum,
        epsilon=self._norm_epsilon,
        synchronized=self._use_sync_bn)(x)
    return tf_utils.get_activation(self._activation)(x)

  def _build_stem(self, inputs, stem_type):
    """Builds stem layer.

    Args:
      inputs: The input tensor of the backbone.
      stem_type: A `str`, either 'v0' (a single 7x7 spatial convolution with
        64 filters) or 'v1' (three 3x3 spatial convolutions with 32, 32 and 64
        filters). Only the first convolution of a stem applies the temporal
        kernel size and the strides.

    Returns:
      The output tensor of the stem.

    Raises:
      ValueError: If `stem_type` is not supported.
    """
    temporal_kernel = self._stem_conv_temporal_kernel_size
    temporal_stride = self._stem_conv_temporal_stride

    if stem_type == 'v0':
      return self._stem_conv_bn_act(
          inputs,
          filters=64,
          kernel_size=[temporal_kernel, 7, 7],
          strides=[temporal_stride, 2, 2])

    if stem_type == 'v1':
      x = self._stem_conv_bn_act(
          inputs,
          filters=32,
          kernel_size=[temporal_kernel, 3, 3],
          strides=[temporal_stride, 2, 2])
      x = self._stem_conv_bn_act(
          x, filters=32, kernel_size=[1, 3, 3], strides=[1, 1, 1])
      return self._stem_conv_bn_act(
          x, filters=64, kernel_size=[1, 3, 3], strides=[1, 1, 1])

    raise ValueError(f'Stem type {stem_type} not supported.')

  def _block_group(self,
                   inputs: tf.Tensor,
                   filters: int,
                   temporal_kernel_sizes: Tuple[int],
                   temporal_strides: int,
                   spatial_strides: int,
                   block_fn: Callable[
                       ...,
                       tf_keras.layers.Layer] = nn_blocks_3d.BottleneckBlock3D,
                   block_repeats: int = 1,
                   stochastic_depth_drop_rate: float = 0.0,
                   use_self_gating: bool = False,
                   name: str = 'block_group'):
    """Creates one group of blocks for the ResNet3D model.

    Args:
      inputs: A `tf.Tensor` of input features. Inside this backbone it is the
        5D output of the stem pooling layer or of the previous block group.
      filters: An `int` of number of filters for the first convolution of the
        layer.
      temporal_kernel_sizes: A tuple that specifies the temporal kernel sizes
        for each block in the current group.
      temporal_strides: An `int` of temporal strides for the first convolution
        in this group.
      spatial_strides: An `int` stride to use for the first convolution of the
        layer. If greater than 1, this layer will downsample the input.
      block_fn: A callable that builds one block layer. Defaults to
        `nn_blocks_3d.BottleneckBlock3D`.
      block_repeats: An `int` of number of blocks contained in the layer.
      stochastic_depth_drop_rate: A `float` of drop rate of the current block
        group.
      use_self_gating: A `bool` that specifies whether to apply self-gating
        module or not.
      name: A `str` name for the block.

    Returns:
      The output `tf.Tensor` of the block layer.

    Raises:
      ValueError: If `block_repeats` is smaller than 1 or does not equal the
        number of elements in `temporal_kernel_sizes`.
    """
    if len(temporal_kernel_sizes) != block_repeats:
      raise ValueError(
          'Number of elements in `temporal_kernel_sizes` must equal to `block_repeats`.'
      )
    if block_repeats < 1:
      raise ValueError(
          '`block_repeats` must be at least 1, got {}.'.format(block_repeats))

    last_block = block_repeats - 1
    x = inputs
    for i in range(block_repeats):
      # Only the first block applies the group's strides; the following
      # blocks keep the resolution.
      is_first_block = i == 0
      x = block_fn(
          filters=filters,
          temporal_kernel_size=temporal_kernel_sizes[i],
          temporal_strides=temporal_strides if is_first_block else 1,
          spatial_strides=spatial_strides if is_first_block else 1,
          stochastic_depth_drop_rate=stochastic_depth_drop_rate,
          # Only apply self-gating module in the last block.
          use_self_gating=use_self_gating if i == last_block else False,
          se_ratio=self._se_ratio,
          kernel_initializer=self._kernel_initializer,
          kernel_regularizer=self._kernel_regularizer,
          bias_regularizer=self._bias_regularizer,
          activation=self._activation,
          use_sync_bn=self._use_sync_bn,
          norm_momentum=self._norm_momentum,
          norm_epsilon=self._norm_epsilon)(
              x)

    return tf.identity(x, name=name)

  def get_config(self):
    """Returns the constructor arguments needed to rebuild the model.

    Note that `input_specs` is not part of the returned config, so a model
    restored through `from_config` is built with the default `input_specs`.
    """
    config_dict = {
        'model_id': self._model_id,
        'temporal_strides': self._temporal_strides,
        'temporal_kernel_sizes': self._temporal_kernel_sizes,
        'stem_type': self._stem_type,
        'stem_conv_temporal_kernel_size': self._stem_conv_temporal_kernel_size,
        'stem_conv_temporal_stride': self._stem_conv_temporal_stride,
        'stem_pool_temporal_stride': self._stem_pool_temporal_stride,
        'use_self_gating': self._use_self_gating,
        'se_ratio': self._se_ratio,
        'init_stochastic_depth_rate': self._init_stochastic_depth_rate,
        'activation': self._activation,
        'use_sync_bn': self._use_sync_bn,
        'norm_momentum': self._norm_momentum,
        'norm_epsilon': self._norm_epsilon,
        'kernel_initializer': self._kernel_initializer,
        'kernel_regularizer': self._kernel_regularizer,
        'bias_regularizer': self._bias_regularizer,
    }
    return config_dict

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Rebuilds a model from `config`; `custom_objects` is not used."""
    return cls(**config)

  @property
  def output_specs(self):
    """A dict of {level: TensorShape} pairs for the model output."""
    return self._output_specs


@factory.register_backbone_builder('resnet_3d')
def build_resnet3d(
    input_specs: tf_keras.layers.InputSpec,
    backbone_config: hyperparams.Config,
    norm_activation_config: hyperparams.Config,
    l2_regularizer: Optional[tf_keras.regularizers.Regularizer] = None
) -> tf_keras.Model:
  """Builds ResNet 3d backbone from a config.

  Args:
    input_specs: A `tf_keras.layers.InputSpec` of the input tensor.
    backbone_config: A config whose `get()` result provides `model_id`,
      `block_specs`, the stem settings, `stochastic_depth_drop_rate` and
      `se_ratio`.
    norm_activation_config: A config providing `activation`, `use_sync_bn`,
      `norm_momentum` and `norm_epsilon`.
    l2_regularizer: An optional regularizer that is used as the kernel
      regularizer of the backbone.

  Returns:
    A `ResNet3D` model.
  """
  cfg = backbone_config.get()

  # Flatten configs before passing to the backbone: one entry per block spec.
  block_specs = list(cfg.block_specs)
  strides_per_group = [spec.temporal_strides for spec in block_specs]
  kernel_sizes_per_group = [
      spec.temporal_kernel_sizes for spec in block_specs
  ]
  self_gating_per_group = [spec.use_self_gating for spec in block_specs]

  model_kwargs = dict(
      model_id=cfg.model_id,
      temporal_strides=strides_per_group,
      temporal_kernel_sizes=kernel_sizes_per_group,
      use_self_gating=self_gating_per_group,
      input_specs=input_specs,
      stem_type=cfg.stem_type,
      stem_conv_temporal_kernel_size=cfg.stem_conv_temporal_kernel_size,
      stem_conv_temporal_stride=cfg.stem_conv_temporal_stride,
      stem_pool_temporal_stride=cfg.stem_pool_temporal_stride,
      init_stochastic_depth_rate=cfg.stochastic_depth_drop_rate,
      se_ratio=cfg.se_ratio,
      activation=norm_activation_config.activation,
      use_sync_bn=norm_activation_config.use_sync_bn,
      norm_momentum=norm_activation_config.norm_momentum,
      norm_epsilon=norm_activation_config.norm_epsilon,
      kernel_regularizer=l2_regularizer)
  return ResNet3D(**model_kwargs)


@factory.register_backbone_builder('resnet_3d_rs')
def build_resnet3d_rs(
    input_specs: tf_keras.layers.InputSpec,
    backbone_config: hyperparams.Config,
    norm_activation_config: hyperparams.Config,
    l2_regularizer: Optional[tf_keras.regularizers.Regularizer] = None
) -> tf_keras.Model:
  """Builds ResNet-3D-RS backbone from a config.

  Each block spec's `temporal_kernel_sizes` is repeated once per block of the
  matching group in `RESNET_SPECS[model_id]`, so the config only has to state
  the repeating pattern rather than one size per block.

  Args:
    input_specs: A `tf_keras.layers.InputSpec` of the input tensor.
    backbone_config: A config whose `get()` result provides `model_id`,
      `block_specs`, the stem settings, `stochastic_depth_drop_rate` and
      `se_ratio`.
    norm_activation_config: A config providing `activation`, `use_sync_bn`,
      `norm_momentum` and `norm_epsilon`.
    l2_regularizer: An optional regularizer that is used as the kernel
      regularizer of the backbone.

  Returns:
    A `ResNet3D` model.
  """
  cfg = backbone_config.get()

  # Flatten configs before passing to the backbone: one entry per block spec.
  block_specs = list(cfg.block_specs)
  strides_per_group = [spec.temporal_strides for spec in block_specs]
  self_gating_per_group = [spec.use_self_gating for spec in block_specs]
  # The last field of a RESNET_SPECS entry is the group's block repeat count.
  kernel_sizes_per_group = [
      list(spec.temporal_kernel_sizes) * RESNET_SPECS[cfg.model_id][group][-1]
      for group, spec in enumerate(block_specs)
  ]

  model_kwargs = dict(
      model_id=cfg.model_id,
      temporal_strides=strides_per_group,
      temporal_kernel_sizes=kernel_sizes_per_group,
      use_self_gating=self_gating_per_group,
      input_specs=input_specs,
      stem_type=cfg.stem_type,
      stem_conv_temporal_kernel_size=cfg.stem_conv_temporal_kernel_size,
      stem_conv_temporal_stride=cfg.stem_conv_temporal_stride,
      stem_pool_temporal_stride=cfg.stem_pool_temporal_stride,
      init_stochastic_depth_rate=cfg.stochastic_depth_drop_rate,
      se_ratio=cfg.se_ratio,
      activation=norm_activation_config.activation,
      use_sync_bn=norm_activation_config.use_sync_bn,
      norm_momentum=norm_activation_config.norm_momentum,
      norm_epsilon=norm_activation_config.norm_epsilon,
      kernel_regularizer=l2_regularizer)
  return ResNet3D(**model_kwargs)