# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Factory methods to build models."""

from typing import Optional

import tensorflow as tf, tf_keras

from official.vision.configs import image_classification as classification_cfg
from official.vision.configs import maskrcnn as maskrcnn_cfg
from official.vision.configs import retinanet as retinanet_cfg
from official.vision.configs import semantic_segmentation as segmentation_cfg
from official.vision.modeling import backbones
from official.vision.modeling import classification_model
from official.vision.modeling import decoders
from official.vision.modeling import maskrcnn_model
from official.vision.modeling import retinanet_model
from official.vision.modeling import segmentation_model
from official.vision.modeling.heads import dense_prediction_heads
from official.vision.modeling.heads import instance_heads
from official.vision.modeling.heads import segmentation_heads
from official.vision.modeling.layers import detection_generator
from official.vision.modeling.layers import mask_sampler
from official.vision.modeling.layers import roi_aligner
from official.vision.modeling.layers import roi_generator
from official.vision.modeling.layers import roi_sampler


def build_classification_model(
    input_specs: tf_keras.layers.InputSpec,
    model_config: classification_cfg.ImageClassificationModel,
    l2_regularizer: Optional[tf_keras.regularizers.Regularizer] = None,
    skip_logits_layer: bool = False,
    backbone: Optional[tf_keras.Model] = None) -> tf_keras.Model:
  """Builds an image classification model from its config.

  Args:
    input_specs: Input spec of the image tensor. Forwarded to the backbone
      factory (when the backbone is built here) and to `ClassificationModel`.
    model_config: Classification config providing the backbone config, the
      norm/activation settings and the head hyperparameters.
    l2_regularizer: Optional kernel regularizer, applied to the backbone (when
      built here) and to the classification head.
    skip_logits_layer: Forwarded unchanged to `ClassificationModel`.
    backbone: Optional pre-built backbone. When falsy, a backbone is built from
      `model_config.backbone`.

  Returns:
    A `classification_model.ClassificationModel` instance.
  """
  norm_cfg = model_config.norm_activation
  # Reuse the caller's backbone when one is given; otherwise build it here.
  backbone = backbone or backbones.factory.build_backbone(
      input_specs=input_specs,
      backbone_config=model_config.backbone,
      norm_activation_config=norm_cfg,
      l2_regularizer=l2_regularizer)

  return classification_model.ClassificationModel(
      backbone=backbone,
      num_classes=model_config.num_classes,
      input_specs=input_specs,
      dropout_rate=model_config.dropout_rate,
      kernel_initializer=model_config.kernel_initializer,
      kernel_regularizer=l2_regularizer,
      add_head_batch_norm=model_config.add_head_batch_norm,
      use_sync_bn=norm_cfg.use_sync_bn,
      norm_momentum=norm_cfg.norm_momentum,
      norm_epsilon=norm_cfg.norm_epsilon,
      skip_logits_layer=skip_logits_layer)


def _build_maskrcnn_detection_head(
    model_config: maskrcnn_cfg.MaskRCNN,
    l2_regularizer: Optional[tf_keras.regularizers.Regularizer],
    name: str) -> tf_keras.layers.Layer:
  """Builds one box detection head (first stage or a cascade stage).

  Every stage is configured from the same `model_config.detection_head` and
  `model_config.norm_activation`; only the layer `name` differs.
  """
  head_config = model_config.detection_head
  norm_activation_config = model_config.norm_activation
  return instance_heads.DetectionHead(
      num_classes=model_config.num_classes,
      num_convs=head_config.num_convs,
      num_filters=head_config.num_filters,
      use_separable_conv=head_config.use_separable_conv,
      num_fcs=head_config.num_fcs,
      fc_dims=head_config.fc_dims,
      class_agnostic_bbox_pred=head_config.class_agnostic_bbox_pred,
      activation=norm_activation_config.activation,
      use_sync_bn=norm_activation_config.use_sync_bn,
      norm_momentum=norm_activation_config.norm_momentum,
      norm_epsilon=norm_activation_config.norm_epsilon,
      kernel_regularizer=l2_regularizer,
      name=name)


def _build_maskrcnn_roi_samplers(roi_sampler_config) -> list:
  """Builds the first-stage ROI sampler followed by one per cascade stage."""
  samplers = [
      roi_sampler.ROISampler(
          mix_gt_boxes=roi_sampler_config.mix_gt_boxes,
          num_sampled_rois=roi_sampler_config.num_sampled_rois,
          foreground_fraction=roi_sampler_config.foreground_fraction,
          foreground_iou_threshold=(
              roi_sampler_config.foreground_iou_threshold),
          background_iou_high_threshold=(
              roi_sampler_config.background_iou_high_threshold),
          background_iou_low_threshold=(
              roi_sampler_config.background_iou_low_threshold))
  ]
  # Additional ROI samplers for the cascade heads: no ground-truth mixing,
  # the stage IoU is both the foreground and the upper background threshold,
  # and subsampling is skipped.
  for iou in roi_sampler_config.cascade_iou_thresholds or []:
    samplers.append(
        roi_sampler.ROISampler(
            mix_gt_boxes=False,
            num_sampled_rois=roi_sampler_config.num_sampled_rois,
            foreground_iou_threshold=iou,
            background_iou_high_threshold=iou,
            background_iou_low_threshold=0.0,
            skip_subsampling=True))
  return samplers


def _build_maskrcnn_mask_components(
    model_config: maskrcnn_cfg.MaskRCNN,
    l2_regularizer: Optional[tf_keras.regularizers.Regularizer]):
  """Returns `(mask_head, mask_sampler, mask_roi_aligner)`.

  All three are None when `model_config.include_mask` is falsy.
  """
  if not model_config.include_mask:
    return None, None, None

  mask_head_config = model_config.mask_head
  norm_activation_config = model_config.norm_activation
  mask_head = instance_heads.MaskHead(
      num_classes=model_config.num_classes,
      upsample_factor=mask_head_config.upsample_factor,
      num_convs=mask_head_config.num_convs,
      num_filters=mask_head_config.num_filters,
      use_separable_conv=mask_head_config.use_separable_conv,
      activation=norm_activation_config.activation,
      # Pass the sync-BN setting like every other head in this model, so the
      # mask branch does not silently fall back to unsynchronized BN.
      use_sync_bn=norm_activation_config.use_sync_bn,
      norm_momentum=norm_activation_config.norm_momentum,
      norm_epsilon=norm_activation_config.norm_epsilon,
      kernel_regularizer=l2_regularizer,
      class_agnostic=mask_head_config.class_agnostic)

  mask_sampler_obj = mask_sampler.MaskSampler(
      # Targets are sized to the mask head output: aligned crop * upsampling.
      mask_target_size=(
          model_config.mask_roi_aligner.crop_size *
          mask_head_config.upsample_factor),
      num_sampled_masks=model_config.mask_sampler.num_sampled_masks)

  mask_roi_aligner_obj = roi_aligner.MultilevelROIAligner(
      crop_size=model_config.mask_roi_aligner.crop_size,
      sample_offset=model_config.mask_roi_aligner.sample_offset)
  return mask_head, mask_sampler_obj, mask_roi_aligner_obj


def build_maskrcnn(input_specs: tf_keras.layers.InputSpec,
                   model_config: maskrcnn_cfg.MaskRCNN,
                   l2_regularizer: Optional[
                       tf_keras.regularizers.Regularizer] = None,
                   backbone: Optional[tf_keras.Model] = None,
                   decoder: Optional[tf_keras.Model] = None) -> tf_keras.Model:
  """Builds a Mask R-CNN model (optionally with cascade detection heads).

  Args:
    input_specs: Input spec of the image tensor. Its shape without the leading
      dimension defines the symbolic input the backbone is called on.
    model_config: Mask R-CNN config.
    l2_regularizer: Optional kernel regularizer shared by all built layers.
    backbone: Optional pre-built backbone; built from `model_config` if None.
    decoder: Optional pre-built decoder; built from `model_config` if None.

  Returns:
    A `maskrcnn_model.MaskRCNNModel`. When
    `model_config.roi_sampler.cascade_iou_thresholds` is non-empty, the model
    receives a list of detection heads (one extra per threshold) and one ROI
    sampler per head.
  """
  norm_activation_config = model_config.norm_activation
  if backbone is None:
    backbone = backbones.factory.build_backbone(
        input_specs=input_specs,
        backbone_config=model_config.backbone,
        norm_activation_config=norm_activation_config,
        l2_regularizer=l2_regularizer)
  backbone_features = backbone(tf_keras.Input(input_specs.shape[1:]))

  if decoder is None:
    decoder = decoders.factory.build_decoder(
        input_specs=backbone.output_specs,
        model_config=model_config,
        l2_regularizer=l2_regularizer)

  rpn_head_config = model_config.rpn_head
  roi_generator_config = model_config.roi_generator
  roi_sampler_config = model_config.roi_sampler
  roi_aligner_config = model_config.roi_aligner
  detection_head_config = model_config.detection_head
  generator_config = model_config.detection_generator
  num_anchors_per_location = (
      len(model_config.anchor.aspect_ratios) * model_config.anchor.num_scales)

  rpn_head = dense_prediction_heads.RPNHead(
      min_level=model_config.min_level,
      max_level=model_config.max_level,
      num_anchors_per_location=num_anchors_per_location,
      num_convs=rpn_head_config.num_convs,
      num_filters=rpn_head_config.num_filters,
      use_separable_conv=rpn_head_config.use_separable_conv,
      activation=norm_activation_config.activation,
      use_sync_bn=norm_activation_config.use_sync_bn,
      norm_momentum=norm_activation_config.norm_momentum,
      norm_epsilon=norm_activation_config.norm_epsilon,
      kernel_regularizer=l2_regularizer)

  detection_head = _build_maskrcnn_detection_head(
      model_config, l2_regularizer, name='detection_head')

  # Call the decoder and RPN head once on symbolic features so their weights
  # are created. The build factory may hand back no decoder, hence the guard.
  if decoder is not None:
    decoder_features = decoder(backbone_features)
    rpn_head(decoder_features)

  cascade_iou_thresholds = roi_sampler_config.cascade_iou_thresholds
  if cascade_iou_thresholds:
    # Cascade R-CNN: one extra detection head per cascade IoU threshold.
    detection_head = [detection_head] + [
        _build_maskrcnn_detection_head(
            model_config, l2_regularizer,
            name='detection_head_{}'.format(stage))
        for stage in range(1, len(cascade_iou_thresholds) + 1)
    ]

  roi_generator_obj = roi_generator.MultilevelROIGenerator(
      pre_nms_top_k=roi_generator_config.pre_nms_top_k,
      pre_nms_score_threshold=roi_generator_config.pre_nms_score_threshold,
      pre_nms_min_size_threshold=(
          roi_generator_config.pre_nms_min_size_threshold),
      nms_iou_threshold=roi_generator_config.nms_iou_threshold,
      num_proposals=roi_generator_config.num_proposals,
      test_pre_nms_top_k=roi_generator_config.test_pre_nms_top_k,
      test_pre_nms_score_threshold=(
          roi_generator_config.test_pre_nms_score_threshold),
      test_pre_nms_min_size_threshold=(
          roi_generator_config.test_pre_nms_min_size_threshold),
      test_nms_iou_threshold=roi_generator_config.test_nms_iou_threshold,
      test_num_proposals=roi_generator_config.test_num_proposals,
      use_batched_nms=roi_generator_config.use_batched_nms)

  roi_sampler_cascade = _build_maskrcnn_roi_samplers(roi_sampler_config)

  roi_aligner_obj = roi_aligner.MultilevelROIAligner(
      crop_size=roi_aligner_config.crop_size,
      sample_offset=roi_aligner_config.sample_offset)

  detection_generator_obj = detection_generator.DetectionGenerator(
      apply_nms=generator_config.apply_nms,
      pre_nms_top_k=generator_config.pre_nms_top_k,
      pre_nms_score_threshold=generator_config.pre_nms_score_threshold,
      nms_iou_threshold=generator_config.nms_iou_threshold,
      max_num_detections=generator_config.max_num_detections,
      nms_version=generator_config.nms_version,
      use_cpu_nms=generator_config.use_cpu_nms,
      soft_nms_sigma=generator_config.soft_nms_sigma,
      use_sigmoid_probability=generator_config.use_sigmoid_probability)

  mask_head, mask_sampler_obj, mask_roi_aligner_obj = (
      _build_maskrcnn_mask_components(model_config, l2_regularizer))

  model = maskrcnn_model.MaskRCNNModel(
      backbone=backbone,
      decoder=decoder,
      rpn_head=rpn_head,
      detection_head=detection_head,
      roi_generator=roi_generator_obj,
      roi_sampler=roi_sampler_cascade,
      roi_aligner=roi_aligner_obj,
      detection_generator=detection_generator_obj,
      mask_head=mask_head,
      mask_sampler=mask_sampler_obj,
      mask_roi_aligner=mask_roi_aligner_obj,
      class_agnostic_bbox_pred=detection_head_config.class_agnostic_bbox_pred,
      cascade_class_ensemble=detection_head_config.cascade_class_ensemble,
      min_level=model_config.min_level,
      max_level=model_config.max_level,
      num_scales=model_config.anchor.num_scales,
      aspect_ratios=model_config.anchor.aspect_ratios,
      anchor_size=model_config.anchor.anchor_size,
      outer_boxes_scale=model_config.outer_boxes_scale)
  return model


def build_retinanet(
    input_specs: tf_keras.layers.InputSpec,
    model_config: retinanet_cfg.RetinaNet,
    l2_regularizer: Optional[tf_keras.regularizers.Regularizer] = None,
    backbone: Optional[tf_keras.Model] = None,
    decoder: Optional[tf_keras.Model] = None
) -> tf_keras.Model:
  """Builds a RetinaNet model from its config.

  Args:
    input_specs: Input spec of the image tensor. Its shape without the leading
      dimension defines the symbolic input the backbone is called on, and its
      dims 1 and 2 are stored as `input_image_size` in the TFLite
      post-processing config.
    model_config: RetinaNet config (backbone, decoder, head, anchors and
      detection generator settings).
    l2_regularizer: Optional kernel regularizer for the backbone/decoder (when
      built here) and for the head.
    backbone: Optional pre-built backbone; built from `model_config` if falsy.
    decoder: Optional pre-built decoder; built from `model_config` if falsy.

  Returns:
    A `retinanet_model.RetinaNetModel`.
  """
  norm_cfg = model_config.norm_activation
  backbone = backbone or backbones.factory.build_backbone(
      input_specs=input_specs,
      backbone_config=model_config.backbone,
      norm_activation_config=norm_cfg,
      l2_regularizer=l2_regularizer)
  backbone_features = backbone(tf_keras.Input(input_specs.shape[1:]))

  decoder = decoder or decoders.factory.build_decoder(
      input_specs=backbone.output_specs,
      model_config=model_config,
      l2_regularizer=l2_regularizer)

  anchor_cfg = model_config.anchor
  head_cfg = model_config.head
  gen_cfg = model_config.detection_generator
  anchors_per_location = len(anchor_cfg.aspect_ratios) * anchor_cfg.num_scales
  attribute_head_dicts = [
      attr_cfg.as_dict() for attr_cfg in (head_cfg.attribute_heads or [])
  ]

  head = dense_prediction_heads.RetinaNetHead(
      min_level=model_config.min_level,
      max_level=model_config.max_level,
      num_classes=model_config.num_classes,
      num_anchors_per_location=anchors_per_location,
      num_convs=head_cfg.num_convs,
      num_filters=head_cfg.num_filters,
      attribute_heads=attribute_head_dicts,
      share_classification_heads=head_cfg.share_classification_heads,
      use_separable_conv=head_cfg.use_separable_conv,
      activation=norm_cfg.activation,
      use_sync_bn=norm_cfg.use_sync_bn,
      norm_momentum=norm_cfg.norm_momentum,
      norm_epsilon=norm_cfg.norm_epsilon,
      kernel_regularizer=l2_regularizer,
      share_level_convs=head_cfg.share_level_convs,
  )

  # Run the decoder and head once on symbolic features so their trainable
  # weights get created; skipped when `decoder` is falsy.
  if decoder:
    head(decoder(backbone_features))

  # Record dims 1 and 2 of the input spec as `input_image_size` for TFLite
  # post-processing.
  tflite_cfg = gen_cfg.tflite_post_processing.as_dict()
  tflite_cfg['input_image_size'] = (input_specs.shape[1], input_specs.shape[2])

  detection_generator_obj = detection_generator.MultilevelDetectionGenerator(
      apply_nms=gen_cfg.apply_nms,
      pre_nms_top_k=gen_cfg.pre_nms_top_k,
      pre_nms_score_threshold=gen_cfg.pre_nms_score_threshold,
      nms_iou_threshold=gen_cfg.nms_iou_threshold,
      max_num_detections=gen_cfg.max_num_detections,
      nms_version=gen_cfg.nms_version,
      use_cpu_nms=gen_cfg.use_cpu_nms,
      soft_nms_sigma=gen_cfg.soft_nms_sigma,
      tflite_post_processing_config=tflite_cfg,
      return_decoded=gen_cfg.return_decoded,
      use_class_agnostic_nms=gen_cfg.use_class_agnostic_nms,
      box_coder_weights=gen_cfg.box_coder_weights,
  )

  return retinanet_model.RetinaNetModel(
      backbone,
      decoder,
      head,
      detection_generator_obj,
      min_level=model_config.min_level,
      max_level=model_config.max_level,
      num_scales=anchor_cfg.num_scales,
      aspect_ratios=anchor_cfg.aspect_ratios,
      anchor_size=anchor_cfg.anchor_size)


def build_segmentation_model(
    input_specs: tf_keras.layers.InputSpec,
    model_config: segmentation_cfg.SemanticSegmentationModel,
    l2_regularizer: Optional[tf_keras.regularizers.Regularizer] = None,
    backbone: Optional[tf_keras.Model] = None,
    decoder: Optional[tf_keras.Model] = None
) -> tf_keras.Model:
  """Builds a semantic segmentation model from its config.

  Args:
    input_specs: Input spec of the image tensor, forwarded to the backbone
      factory when the backbone is built here.
    model_config: Segmentation config (backbone, decoder, head, optional mask
      scoring head and norm/activation settings).
    l2_regularizer: Optional kernel regularizer for the backbone/decoder (when
      built here) and for the heads.
    backbone: Optional pre-built backbone; built from `model_config` if falsy.
    decoder: Optional pre-built decoder; built from `model_config` if falsy.

  Returns:
    A `segmentation_model.SegmentationModel`. Its mask scoring head is None
    unless `model_config.mask_scoring_head` is truthy.
  """
  norm_cfg = model_config.norm_activation
  backbone = backbone or backbones.factory.build_backbone(
      input_specs=input_specs,
      backbone_config=model_config.backbone,
      norm_activation_config=norm_cfg,
      l2_regularizer=l2_regularizer)

  decoder = decoder or decoders.factory.build_decoder(
      input_specs=backbone.output_specs,
      model_config=model_config,
      l2_regularizer=l2_regularizer)

  head_cfg = model_config.head
  head = segmentation_heads.SegmentationHead(
      num_classes=model_config.num_classes,
      level=head_cfg.level,
      num_convs=head_cfg.num_convs,
      prediction_kernel_size=head_cfg.prediction_kernel_size,
      num_filters=head_cfg.num_filters,
      use_depthwise_convolution=head_cfg.use_depthwise_convolution,
      upsample_factor=head_cfg.upsample_factor,
      feature_fusion=head_cfg.feature_fusion,
      low_level=head_cfg.low_level,
      low_level_num_filters=head_cfg.low_level_num_filters,
      activation=norm_cfg.activation,
      logit_activation=head_cfg.logit_activation,
      use_sync_bn=norm_cfg.use_sync_bn,
      norm_momentum=norm_cfg.norm_momentum,
      norm_epsilon=norm_cfg.norm_epsilon,
      kernel_regularizer=l2_regularizer)

  # The optional mask scoring head takes its own hyperparameters from the
  # config dict plus the shared norm/activation settings.
  mask_scoring_head = (
      segmentation_heads.MaskScoring(
          num_classes=model_config.num_classes,
          **model_config.mask_scoring_head.as_dict(),
          activation=norm_cfg.activation,
          use_sync_bn=norm_cfg.use_sync_bn,
          norm_momentum=norm_cfg.norm_momentum,
          norm_epsilon=norm_cfg.norm_epsilon,
          kernel_regularizer=l2_regularizer)
      if model_config.mask_scoring_head else None)

  return segmentation_model.SegmentationModel(
      backbone, decoder, head, mask_scoring_head=mask_scoring_head)