# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Detection input and model functions for serving/inference."""

import math
from typing import Mapping, Tuple

from absl import logging
import tensorflow as tf, tf_keras

from official.vision import configs
from official.vision.modeling import factory
from official.vision.ops import anchor
from official.vision.ops import box_ops
from official.vision.ops import preprocess_ops
from official.vision.serving import export_base


class DetectionModule(export_base.ExportModule):
  """Export module for detection models (Mask R-CNN and RetinaNet).

  Builds the detection model from `self.params.task.model`, preprocesses raw
  images (normalization, resize/crop/pad and anchor generation) and converts
  the model outputs into a flat dictionary of serving tensors.
  """

  @property
  def _padded_size(self):
    """Returns the spatial size serving images are padded to.

    When `task.train_data.parser.pad` is set, the input image size is passed
    through `preprocess_ops.compute_padded_size` with a stride of
    `2**task.model.max_level` (presumably rounding up to a multiple of it);
    otherwise the input image size is returned unchanged, so the result may be
    whatever sequence type `self._input_image_size` is.
    """
    if self.params.task.train_data.parser.pad:
      return preprocess_ops.compute_padded_size(
          self._input_image_size, 2**self.params.task.model.max_level
      )
    return self._input_image_size

  def _build_model(self):
    """Builds the detection model described by `self.params.task.model`.

    If the export uses a dynamic batch size (`self._batch_size is None`) and
    the configured NMS version is not one of `batched`, `v2` or `v3`, the NMS
    version in `self.params` is overwritten with `batched` before building.

    Returns:
      The model returned by `factory.build_maskrcnn` or
      `factory.build_retinanet`.

    Raises:
      ValueError: If the model config is neither MaskRCNN nor RetinaNet.
    """
    nms_versions_supporting_dynamic_batch_size = {'batched', 'v2', 'v3'}
    nms_version = self.params.task.model.detection_generator.nms_version
    if (self._batch_size is None and
        nms_version not in nms_versions_supporting_dynamic_batch_size):
      logging.info('nms_version is set to `batched` because `%s` '
                   'does not support dynamic batch size.', nms_version)
      self.params.task.model.detection_generator.nms_version = 'batched'

    input_specs = tf_keras.layers.InputSpec(
        shape=[self._batch_size, *self._padded_size, 3])

    model_config = self.params.task.model
    if isinstance(model_config, configs.maskrcnn.MaskRCNN):
      return factory.build_maskrcnn(
          input_specs=input_specs, model_config=model_config)
    if isinstance(model_config, configs.retinanet.RetinaNet):
      return factory.build_retinanet(
          input_specs=input_specs, model_config=model_config)
    raise ValueError('Detection module not implemented for {} model.'.format(
        type(model_config)))

  def _build_anchor_boxes(self):
    """Builds anchor boxes for the padded serving image size.

    Returns:
      The output of the generator from `anchor.build_anchor_generator`
      evaluated at `self._padded_size`. `preprocess` describes it as a dict
      keyed by `str(level)`.
    """
    model_params = self.params.task.model
    input_anchor = anchor.build_anchor_generator(
        min_level=model_params.min_level,
        max_level=model_params.max_level,
        num_scales=model_params.anchor.num_scales,
        aspect_ratios=model_params.anchor.aspect_ratios,
        anchor_size=model_params.anchor.anchor_size)
    return input_anchor(image_size=self._padded_size)

  def _build_inputs(self, image):
    """Builds detection model inputs for a single image.

    The image is normalized with `preprocess_ops.MEAN_RGB` / `STDDEV_RGB`,
    then resized, cropped and padded to `self._padded_size` without scale
    augmentation (`aug_scale_min == aug_scale_max == 1.0`).

    Args:
      image: One image from the batch given to `preprocess`.

    Returns:
      A tuple `(image, anchor_boxes, image_info)`.
    """
    # Normalizes image with mean and std pixel values.
    image = preprocess_ops.normalize_image(
        image, offset=preprocess_ops.MEAN_RGB, scale=preprocess_ops.STDDEV_RGB)

    image, image_info = preprocess_ops.resize_and_crop_image(
        image,
        self._input_image_size,
        padded_size=self._padded_size,
        aug_scale_min=1.0,
        aug_scale_max=1.0,
        keep_aspect_ratio=self.params.task.train_data.parser.keep_aspect_ratio,
    )
    anchor_boxes = self._build_anchor_boxes()

    return image, anchor_boxes, image_info

  def _normalize_coordinates(self, detections_dict, dict_keys, image_info):
    """Normalizes detection coordinates between 0 and 1.

    The boxes are processed in three steps:

    1. Divide by the resize scale (row 2 of `image_info`, tiled to four
       coordinates).
    2. Normalize with `box_ops.normalize_boxes` using the original image size
       (row 0 of `image_info`).
    3. Clip to [0, 1].

    Keys that are not present in `detections_dict` are skipped.

    Args:
      detections_dict: Dictionary containing the output of the model prediction.
      dict_keys: Key names corresponding to the tensors of the output dictionary
        that we want to update.
      image_info: Tensor containing the details of the image resizing.

    Returns:
      detections_dict: The same dictionary, updated in place.
    """
    # These slices do not depend on the key, so compute them once.
    scale = tf.tile(image_info[:, 2:3, :], [1, 1, 2])
    original_size = image_info[:, 0:1, :]
    for key in dict_keys:
      if key not in detections_dict:
        continue
      boxes = box_ops.normalize_boxes(
          detections_dict[key] / scale, original_size)
      detections_dict[key] = tf.clip_by_value(boxes, 0.0, 1.0)

    return detections_dict

  def preprocess(
      self, images: tf.Tensor
  ) -> Tuple[tf.Tensor, Mapping[str, tf.Tensor], tf.Tensor]:
    """Preprocesses a batch of images to be suitable for the model.

    The batch is cast to float32, then `tf.map_fn` runs `_build_inputs` on each
    image to normalize, resize/crop/pad and build anchors. All of this runs
    under `tf.device('cpu:0')`.

    Args:
      images: The images tensor; `tf.map_fn` iterates over its first dimension.

    Returns:
      images: Preprocessed float32 images of shape
        [batch_size, padded_height, padded_width, 3].
      anchor_boxes: Dict mapping `str(level)` to float32 anchor boxes of shape
        [batch_size, ceil(padded_height / 2**level),
        ceil(padded_width / 2**level), num_anchors].
      image_info: float32 Tensor of shape [batch_size, 4, 2] containing the
        details of the image resizing.
    """
    model_params = self.params.task.model
    with tf.device('cpu:0'):
      images = tf.cast(images, dtype=tf.float32)

      # Tensor specs for the per-image map_fn outputs (images, anchor_boxes and
      # image_info). Unpacking accepts `_padded_size` as a list or a tuple.
      images_spec = tf.TensorSpec(
          shape=[*self._padded_size, 3], dtype=tf.float32)

      # The trailing `* 4` presumably accounts for the four box coordinates of
      # each anchor.
      num_anchors = model_params.anchor.num_scales * len(
          model_params.anchor.aspect_ratios) * 4
      anchor_specs = {}
      for level in range(model_params.min_level, model_params.max_level + 1):
        stride = 2**level
        anchor_specs[str(level)] = tf.TensorSpec(
            shape=[
                math.ceil(self._padded_size[0] / stride),
                math.ceil(self._padded_size[1] / stride),
                num_anchors,
            ],
            dtype=tf.float32)

      image_info_spec = tf.TensorSpec(shape=[4, 2], dtype=tf.float32)

      images, anchor_boxes, image_info = tf.nest.map_structure(
          tf.identity,
          tf.map_fn(
              self._build_inputs,
              elems=images,
              fn_output_signature=(images_spec, anchor_specs,
                                   image_info_spec),
              parallel_iterations=32))

      return images, anchor_boxes, image_info

  def serve(self, images: tf.Tensor):
    """Preprocesses images, runs inference and post-processes the outputs.

    Args:
      images: uint8 Tensor of shape [batch_size, None, None, 3]. When
        `self._input_type` is 'tflite' the images are passed to the model
        unchanged and are expected to be preprocessed already.

    Returns:
      A dict of output tensors.

      - With `detection_generator.apply_nms` set, it holds `detection_boxes`,
        `detection_scores`, `detection_classes` and `num_detections`, plus
        `detection_outer_boxes` when the model returns it.
      - Otherwise it holds `decoded_boxes` and `decoded_box_scores`.
      - `detection_masks`, RetinaNet intermediate features (keys starting with
        `backbone_` or `decoder_`) and `image_info` are added when applicable.
    """
    model_params = self.params.task.model
    is_retinanet = isinstance(model_params, configs.retinanet.RetinaNet)

    # Skip image preprocessing when input_type is tflite so it is compatible
    # with TFLite quantization.
    if self._input_type != 'tflite':
      images, anchor_boxes, image_info = self.preprocess(images)
    else:
      with tf.device('cpu:0'):
        anchor_boxes = self._build_anchor_boxes()
        # image_info has the format [[original_height, original_width],
        # [desired_height, desired_width], [y_scale, x_scale],
        # [y_offset, x_offset]]. Here it is built with a batch dimension of 1
        # (shape [1, 4, 2]) because the input image is supposed to be
        # preprocessed already.
        image_info = tf.convert_to_tensor(
            [[
                self._input_image_size,
                self._input_image_size,
                [1.0, 1.0],
                [0, 0],
            ]],
            dtype=tf.float32,
        )
    input_image_shape = image_info[:, 1, :]

    # To overcome keras.Model extra limitation to save a model with layers that
    # have multiple inputs, we use `model.call` here to trigger the forward
    # path. Note that, this disables some keras magics happens in `__call__`.
    model_call_kwargs = {
        'images': images,
        'image_shape': input_image_shape,
        'anchor_boxes': anchor_boxes,
        'training': False,
    }
    # export_config is only read for RetinaNet models.
    # TODO(huizhongc): Add export_config to fasterrcnn and maskrcnn as needed.
    export_config = self.params.task.export_config if is_retinanet else None
    if is_retinanet:
      model_call_kwargs['output_intermediate_features'] = (
          export_config.output_intermediate_features
      )
    detections = self.model.call(**model_call_kwargs)

    if model_params.detection_generator.apply_nms:
      if is_retinanet:
        # Normalize detection box coordinates to [0, 1].
        if export_config.output_normalized_coordinates:
          detections = self._normalize_coordinates(
              detections, ['detection_boxes', 'detection_outer_boxes'],
              image_info)

        # Cast num_detections and detection_classes to float. This allows the
        # model inference to work on chain (go/chain) as chain requires floating
        # point outputs.
        if export_config.cast_num_detections_to_float:
          detections['num_detections'] = tf.cast(
              detections['num_detections'], dtype=tf.float32)
        if export_config.cast_detection_classes_to_float:
          detections['detection_classes'] = tf.cast(
              detections['detection_classes'], dtype=tf.float32)

      final_outputs = {
          'detection_boxes': detections['detection_boxes'],
          'detection_scores': detections['detection_scores'],
          'detection_classes': detections['detection_classes'],
          'num_detections': detections['num_detections'],
      }
      if 'detection_outer_boxes' in detections:
        final_outputs['detection_outer_boxes'] = (
            detections['detection_outer_boxes'])
    else:
      # Normalize decoded box coordinates to [0, 1] (RetinaNet only).
      if is_retinanet and export_config.output_normalized_coordinates:
        detections = self._normalize_coordinates(
            detections, ['decoded_boxes'], image_info)
      final_outputs = {
          'decoded_boxes': detections['decoded_boxes'],
          'decoded_box_scores': detections['decoded_box_scores'],
      }

    if 'detection_masks' in detections:
      final_outputs['detection_masks'] = detections['detection_masks']
    if is_retinanet and export_config.output_intermediate_features:
      final_outputs.update({
          k: v
          for k, v in detections.items()
          if k.startswith(('backbone_', 'decoder_'))
      })

    if model_params.detection_generator.nms_version != 'tflite':
      final_outputs['image_info'] = image_info
    return final_outputs