# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Contains definitions of 3D Residual Networks."""
from typing import Callable, List, Tuple, Optional

# Import libraries
import tensorflow as tf, tf_keras

from official.modeling import hyperparams
from official.modeling import tf_utils
from official.vision.modeling.backbones import factory
from official.vision.modeling.layers import nn_blocks_3d
from official.vision.modeling.layers import nn_layers

# Short alias for the Keras layers module used throughout this file.
layers = tf_keras.layers

# Block-group layout for each supported 3D ResNet depth, keyed by `model_id`.
# Each list entry describes one block group as a tuple of
#   (block type name, `filters` passed to every block, number of blocks),
# e.g. ('bottleneck3d', 256, 6) is a group of six bottleneck3d blocks built
# with filters=256. Only the 'bottleneck3d' block type is accepted by
# `ResNet3D._build_model`; the i-th group's output becomes endpoint
# level str(i + 2).
RESNET_SPECS = {
    50: [
        ('bottleneck3d', 64, 3),
        ('bottleneck3d', 128, 4),
        ('bottleneck3d', 256, 6),
        ('bottleneck3d', 512, 3),
    ],
    101: [
        ('bottleneck3d', 64, 3),
        ('bottleneck3d', 128, 4),
        ('bottleneck3d', 256, 23),
        ('bottleneck3d', 512, 3),
    ],
    152: [
        ('bottleneck3d', 64, 3),
        ('bottleneck3d', 128, 8),
        ('bottleneck3d', 256, 36),
        ('bottleneck3d', 512, 3),
    ],
    200: [
        ('bottleneck3d', 64, 3),
        ('bottleneck3d', 128, 24),
        ('bottleneck3d', 256, 36),
        ('bottleneck3d', 512, 3),
    ],
    270: [
        ('bottleneck3d', 64, 4),
        ('bottleneck3d', 128, 29),
        ('bottleneck3d', 256, 53),
        ('bottleneck3d', 512, 4),
    ],
    300: [
        ('bottleneck3d', 64, 4),
        ('bottleneck3d', 128, 36),
        ('bottleneck3d', 256, 54),
        ('bottleneck3d', 512, 4),
    ],
    350: [
        ('bottleneck3d', 64, 4),
        ('bottleneck3d', 128, 36),
        ('bottleneck3d', 256, 72),
        ('bottleneck3d', 512, 4),
    ],
}


@tf_keras.utils.register_keras_serializable(package='Vision')
class ResNet3D(tf_keras.Model):
  """Creates a 3D ResNet family model.

  The network is a stem (convolution(s) followed by a 3D max-pool) and then one
  group of blocks per entry of `RESNET_SPECS[model_id]`. The output of the i-th
  block group is exposed as an endpoint keyed by the string level `str(i + 2)`.
  """

  def __init__(
      self,
      model_id: int,
      temporal_strides: List[int],
      temporal_kernel_sizes: List[Tuple[int]],
      use_self_gating: Optional[List[bool]] = None,
      input_specs: tf_keras.layers.InputSpec = layers.InputSpec(
          shape=[None, None, None, None, 3]),
      stem_type: str = 'v0',
      stem_conv_temporal_kernel_size: int = 5,
      stem_conv_temporal_stride: int = 2,
      stem_pool_temporal_stride: int = 2,
      init_stochastic_depth_rate: float = 0.0,
      activation: str = 'relu',
      se_ratio: Optional[float] = None,
      use_sync_bn: bool = False,
      norm_momentum: float = 0.99,
      norm_epsilon: float = 0.001,
      kernel_initializer: str = 'VarianceScaling',
      kernel_regularizer: Optional[tf_keras.regularizers.Regularizer] = None,
      bias_regularizer: Optional[tf_keras.regularizers.Regularizer] = None,
      **kwargs):
    """Initializes a 3D ResNet model.

    Args:
      model_id: An `int` depth of the ResNet backbone; must be a key of
        `RESNET_SPECS`.
      temporal_strides: A list with one temporal stride per block group. The
        stride is applied by the first block of the group; later blocks use a
        temporal stride of 1.
      temporal_kernel_sizes: A list with one sequence per block group, giving
        the temporal kernel size of every block in that group. Each sequence
        must have as many elements as the group has blocks.
      use_self_gating: A list of booleans, one per block group, specifying
        whether the last block of that group applies the self-gating module.
        If None or empty, self-gating is not applied.
      input_specs: A 5D `tf_keras.layers.InputSpec` of the input tensor; its
        shape without the batch dimension is used to create the model input.
        With `channels_last` this is presumably
        `[batch, time, height, width, channels]`.
      stem_type: A `str` of stem type of ResNet. Default to `v0`. If set to
        `v1`, use ResNet-D type stem (https://arxiv.org/abs/1812.01187).
      stem_conv_temporal_kernel_size: An `int` of temporal kernel size for the
        first conv layer.
      stem_conv_temporal_stride: An `int` of temporal stride for the first conv
        layer.
      stem_pool_temporal_stride: An `int` of temporal stride for the first pool
        layer.
      init_stochastic_depth_rate: A `float` of initial stochastic depth rate;
        the per-group rate is derived from it via
        `nn_layers.get_stochastic_depth_rate`.
      activation: A `str` of name of the activation function.
      se_ratio: A `float` or None. Ratio of the Squeeze-and-Excitation layer.
      use_sync_bn: If True, use synchronized batch normalization.
      norm_momentum: A `float` of normalization momentum for the moving average.
      norm_epsilon: A `float` added to variance to avoid dividing by zero.
      kernel_initializer: A str for kernel initializer of convolutional layers.
      kernel_regularizer: A `tf_keras.regularizers.Regularizer` object for the
        convolutional layers. Default to None.
      bias_regularizer: A `tf_keras.regularizers.Regularizer` object for the
        convolutional layers. Default to None.
      **kwargs: Additional keyword arguments passed to `tf_keras.Model`.

    Raises:
      ValueError: If the per-group specs do not match the number of block
        groups of `model_id`, or if `stem_type` is unsupported.
    """
    self._model_id = model_id
    self._temporal_strides = temporal_strides
    self._temporal_kernel_sizes = temporal_kernel_sizes
    self._input_specs = input_specs
    self._stem_type = stem_type
    self._stem_conv_temporal_kernel_size = stem_conv_temporal_kernel_size
    self._stem_conv_temporal_stride = stem_conv_temporal_stride
    self._stem_pool_temporal_stride = stem_pool_temporal_stride
    self._use_self_gating = use_self_gating
    self._se_ratio = se_ratio
    self._init_stochastic_depth_rate = init_stochastic_depth_rate
    self._use_sync_bn = use_sync_bn
    self._activation = activation
    self._norm_momentum = norm_momentum
    self._norm_epsilon = norm_epsilon
    self._norm = layers.BatchNormalization
    self._kernel_initializer = kernel_initializer
    self._kernel_regularizer = kernel_regularizer
    self._bias_regularizer = bias_regularizer
    # Batch norm normalizes over the channel axis, whose position depends on
    # the global Keras image data format.
    if tf_keras.backend.image_data_format() == 'channels_last':
      self._bn_axis = -1
    else:
      self._bn_axis = 1

    # Build ResNet3D backbone as a functional model.
    inputs = tf_keras.Input(shape=input_specs.shape[1:])
    endpoints = self._build_model(inputs)
    self._output_specs = {
        level: feature.get_shape() for level, feature in endpoints.items()
    }

    super().__init__(inputs=inputs, outputs=endpoints, **kwargs)

  def _build_model(self, inputs):
    """Builds model architecture.

    Args:
      inputs: The Keras input tensor.

    Returns:
      endpoints: A dictionary mapping level strings ('2', '3', ...) to the
        output of the corresponding block group.

    Raises:
      ValueError: If the per-group specs have the wrong length or a block type
        is not supported.
    """
    # Validate the per-group specs before creating any layers.
    resnet_specs = RESNET_SPECS[self._model_id]
    num_groups = len(resnet_specs)
    if (len(self._temporal_strides) != num_groups or
        len(self._temporal_kernel_sizes) != num_groups):
      raise ValueError(
          'Number of blocks in temporal specs should equal to resnet_specs.')
    # An empty/None `use_self_gating` means "no self-gating", so only a
    # non-empty list must match the number of block groups.
    if self._use_self_gating and len(self._use_self_gating) != num_groups:
      raise ValueError(
          'Number of elements in `use_self_gating` should equal to '
          'resnet_specs.')

    # Build stem.
    x = self._build_stem(inputs, stem_type=self._stem_type)

    # Without temporal downsampling, pooling does not mix across frames.
    temporal_kernel_size = 1 if self._stem_pool_temporal_stride == 1 else 3
    x = layers.MaxPool3D(
        pool_size=[temporal_kernel_size, 3, 3],
        strides=[self._stem_pool_temporal_stride, 2, 2],
        padding='same')(x)

    # Build intermediate blocks and endpoints.
    endpoints = {}
    for i, (block_name, filters, block_repeats) in enumerate(resnet_specs):
      if block_name == 'bottleneck3d':
        block_fn = nn_blocks_3d.BottleneckBlock3D
      else:
        raise ValueError('Block fn `{}` is not supported.'.format(block_name))

      level = i + 2
      use_self_gating = (
          self._use_self_gating[i] if self._use_self_gating else False)
      x = self._block_group(
          inputs=x,
          filters=filters,
          temporal_kernel_sizes=self._temporal_kernel_sizes[i],
          temporal_strides=self._temporal_strides[i],
          # The first group keeps spatial resolution (the stem pool already
          # downsampled); each later group halves it.
          spatial_strides=(1 if i == 0 else 2),
          block_fn=block_fn,
          block_repeats=block_repeats,
          stochastic_depth_drop_rate=nn_layers.get_stochastic_depth_rate(
              self._init_stochastic_depth_rate, level, 5),
          use_self_gating=use_self_gating,
          name='block_group_l{}'.format(level))
      endpoints[str(level)] = x

    return endpoints

  def _conv_bn_activation(self, inputs, filters, kernel_size, strides):
    """Applies a bias-free 'same'-padded Conv3D, normalization and activation.

    Args:
      inputs: The input tensor.
      filters: An `int` number of output filters.
      kernel_size: A list of 3 ints, `[temporal, height, width]` kernel size.
      strides: A list of 3 ints, `[temporal, height, width]` strides.

    Returns:
      The activated output tensor.
    """
    x = layers.Conv3D(
        filters=filters,
        kernel_size=kernel_size,
        strides=strides,
        use_bias=False,
        padding='same',
        kernel_initializer=self._kernel_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer)(
            inputs)
    x = self._norm(
        axis=self._bn_axis,
        momentum=self._norm_momentum,
        epsilon=self._norm_epsilon,
        synchronized=self._use_sync_bn)(x)
    return tf_utils.get_activation(self._activation)(x)

  def _build_stem(self, inputs, stem_type):
    """Builds the stem convolution(s).

    Args:
      inputs: The input tensor.
      stem_type: 'v0' for a single 64-filter `[t, 7, 7]` convolution, or 'v1'
        for the ResNet-D style stack of three `[., 3, 3]` convolutions.

    Returns:
      The stem output tensor.

    Raises:
      ValueError: If `stem_type` is not supported.
    """
    temporal_kernel = self._stem_conv_temporal_kernel_size
    temporal_stride = self._stem_conv_temporal_stride
    if stem_type == 'v0':
      x = self._conv_bn_activation(
          inputs, 64, [temporal_kernel, 7, 7], [temporal_stride, 2, 2])
    elif stem_type == 'v1':
      # Only the first convolution downsamples; the rest are stride 1.
      x = self._conv_bn_activation(
          inputs, 32, [temporal_kernel, 3, 3], [temporal_stride, 2, 2])
      x = self._conv_bn_activation(x, 32, [1, 3, 3], [1, 1, 1])
      x = self._conv_bn_activation(x, 64, [1, 3, 3], [1, 1, 1])
    else:
      raise ValueError(f'Stem type {stem_type} not supported.')

    return x

  def _block_group(self,
                   inputs: tf.Tensor,
                   filters: int,
                   temporal_kernel_sizes: Tuple[int],
                   temporal_strides: int,
                   spatial_strides: int,
                   block_fn: Callable[
                       ...,
                       tf_keras.layers.Layer] = nn_blocks_3d.BottleneckBlock3D,
                   block_repeats: int = 1,
                   stochastic_depth_drop_rate: float = 0.0,
                   use_self_gating: bool = False,
                   name: str = 'block_group'):
    """Creates one group of blocks for the ResNet3D model.

    Args:
      inputs: A 5D `tf.Tensor` output by the previous stage.
      filters: An `int` of number of filters for the first convolution of the
        layer.
      temporal_kernel_sizes: A sequence with one temporal kernel size per
        block in the current group.
      temporal_strides: An `int` of temporal strides for the first block in
        this group; later blocks use 1.
      spatial_strides: An `int` stride for the first block of the group; later
        blocks use 1. If greater than 1, this group downsamples the input.
      block_fn: A block layer class, e.g. `nn_blocks_3d.BottleneckBlock3D`,
        constructed with the keyword arguments used below.
      block_repeats: An `int` of number of blocks contained in the group.
      stochastic_depth_drop_rate: A `float` of drop rate of the current block
        group.
      use_self_gating: A `bool` that specifies whether the last block of the
        group applies the self-gating module.
      name: A `str` name for the group's output tensor.

    Returns:
      The output `tf.Tensor` of the block group.

    Raises:
      ValueError: If `len(temporal_kernel_sizes) != block_repeats`.
    """
    if len(temporal_kernel_sizes) != block_repeats:
      raise ValueError(
          'Number of elements in `temporal_kernel_sizes` must equal to '
          '`block_repeats`.')

    x = inputs
    for i in range(block_repeats):
      is_first = i == 0
      is_last = i == block_repeats - 1
      x = block_fn(
          filters=filters,
          temporal_kernel_size=temporal_kernel_sizes[i],
          # Only the first block of the group downsamples.
          temporal_strides=temporal_strides if is_first else 1,
          spatial_strides=spatial_strides if is_first else 1,
          stochastic_depth_drop_rate=stochastic_depth_drop_rate,
          # Only apply self-gating module in the last block.
          use_self_gating=use_self_gating if is_last else False,
          se_ratio=self._se_ratio,
          kernel_initializer=self._kernel_initializer,
          kernel_regularizer=self._kernel_regularizer,
          bias_regularizer=self._bias_regularizer,
          activation=self._activation,
          use_sync_bn=self._use_sync_bn,
          norm_momentum=self._norm_momentum,
          norm_epsilon=self._norm_epsilon)(
              x)

    return tf.identity(x, name=name)

  def get_config(self):
    """Returns the constructor arguments needed to rebuild this model.

    NOTE(review): `input_specs` is not included, so `from_config` rebuilds the
    model with the default `input_specs`.
    """
    config_dict = {
        'model_id': self._model_id,
        'temporal_strides': self._temporal_strides,
        'temporal_kernel_sizes': self._temporal_kernel_sizes,
        'stem_type': self._stem_type,
        'stem_conv_temporal_kernel_size': self._stem_conv_temporal_kernel_size,
        'stem_conv_temporal_stride': self._stem_conv_temporal_stride,
        'stem_pool_temporal_stride': self._stem_pool_temporal_stride,
        'use_self_gating': self._use_self_gating,
        'se_ratio': self._se_ratio,
        'init_stochastic_depth_rate': self._init_stochastic_depth_rate,
        'activation': self._activation,
        'use_sync_bn': self._use_sync_bn,
        'norm_momentum': self._norm_momentum,
        'norm_epsilon': self._norm_epsilon,
        'kernel_initializer': self._kernel_initializer,
        'kernel_regularizer': self._kernel_regularizer,
        'bias_regularizer': self._bias_regularizer,
    }
    return config_dict

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Rebuilds the model from a `get_config` dictionary."""
    return cls(**config)

  @property
  def output_specs(self):
    """A dict of {level: TensorShape} pairs for the model output."""
    return self._output_specs


@factory.register_backbone_builder('resnet_3d')
def build_resnet3d(
    input_specs: tf_keras.layers.InputSpec,
    backbone_config: hyperparams.Config,
    norm_activation_config: hyperparams.Config,
    l2_regularizer: Optional[tf_keras.regularizers.Regularizer] = None
) -> tf_keras.Model:
  """Builds a `ResNet3D` backbone from the `resnet_3d` backbone config.

  Args:
    input_specs: A `tf_keras.layers.InputSpec` describing the model input.
    backbone_config: A config whose `.get()` returns the ResNet3D backbone
      sub-config (model id, block specs, stem and SE settings).
    norm_activation_config: A config supplying the activation function name
      and the batch normalization settings.
    l2_regularizer: Optional regularizer passed as the kernel regularizer.

  Returns:
    A `ResNet3D` model.
  """
  cfg = backbone_config.get()
  norm_cfg = norm_activation_config

  # `block_specs` nests the settings per block group; ResNet3D expects one flat
  # list per setting, indexed by block group.
  group_specs = list(cfg.block_specs)
  strides_per_group = [spec.temporal_strides for spec in group_specs]
  kernel_sizes_per_group = [spec.temporal_kernel_sizes for spec in group_specs]
  gating_per_group = [spec.use_self_gating for spec in group_specs]

  return ResNet3D(
      model_id=cfg.model_id,
      temporal_strides=strides_per_group,
      temporal_kernel_sizes=kernel_sizes_per_group,
      use_self_gating=gating_per_group,
      input_specs=input_specs,
      stem_type=cfg.stem_type,
      stem_conv_temporal_kernel_size=cfg.stem_conv_temporal_kernel_size,
      stem_conv_temporal_stride=cfg.stem_conv_temporal_stride,
      stem_pool_temporal_stride=cfg.stem_pool_temporal_stride,
      init_stochastic_depth_rate=cfg.stochastic_depth_drop_rate,
      se_ratio=cfg.se_ratio,
      activation=norm_cfg.activation,
      use_sync_bn=norm_cfg.use_sync_bn,
      norm_momentum=norm_cfg.norm_momentum,
      norm_epsilon=norm_cfg.norm_epsilon,
      kernel_regularizer=l2_regularizer)


@factory.register_backbone_builder('resnet_3d_rs')
def build_resnet3d_rs(
    input_specs: tf_keras.layers.InputSpec,
    backbone_config: hyperparams.Config,
    norm_activation_config: hyperparams.Config,
    l2_regularizer: Optional[tf_keras.regularizers.Regularizer] = None
) -> tf_keras.Model:
  """Builds a ResNet-3D-RS backbone (a `ResNet3D`) from a config.

  Unlike `resnet_3d`, each block spec here gives a temporal kernel size
  pattern for the whole group; it is tiled once per block of the group, using
  the block count from `RESNET_SPECS[model_id]`.

  Args:
    input_specs: A `tf_keras.layers.InputSpec` describing the model input.
    backbone_config: A config whose `.get()` returns the ResNet-3D-RS backbone
      sub-config.
    norm_activation_config: A config supplying the activation function name
      and the batch normalization settings.
    l2_regularizer: Optional regularizer passed as the kernel regularizer.

  Returns:
    A `ResNet3D` model.
  """
  cfg = backbone_config.get()
  model_id = cfg.model_id

  strides_per_group = []
  kernel_sizes_per_group = []
  gating_per_group = []
  for group_idx, spec in enumerate(cfg.block_specs):
    strides_per_group.append(spec.temporal_strides)
    gating_per_group.append(spec.use_self_gating)
    # Number of blocks in this group, from the architecture table.
    num_blocks = RESNET_SPECS[model_id][group_idx][-1]
    kernel_sizes_per_group.append(
        list(spec.temporal_kernel_sizes) * num_blocks)

  return ResNet3D(
      model_id=model_id,
      temporal_strides=strides_per_group,
      temporal_kernel_sizes=kernel_sizes_per_group,
      use_self_gating=gating_per_group,
      input_specs=input_specs,
      stem_type=cfg.stem_type,
      stem_conv_temporal_kernel_size=cfg.stem_conv_temporal_kernel_size,
      stem_conv_temporal_stride=cfg.stem_conv_temporal_stride,
      stem_pool_temporal_stride=cfg.stem_pool_temporal_stride,
      init_stochastic_depth_rate=cfg.stochastic_depth_drop_rate,
      se_ratio=cfg.se_ratio,
      activation=norm_activation_config.activation,
      use_sync_bn=norm_activation_config.use_sync_bn,
      norm_momentum=norm_activation_config.norm_momentum,
      norm_epsilon=norm_activation_config.norm_epsilon,
      kernel_regularizer=l2_regularizer)