# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""RetinaNet."""

from typing import Any, Mapping, List, Optional, Union, Sequence

# Import libraries
import tensorflow as tf, tf_keras

from official.vision.ops import anchor


@tf_keras.utils.register_keras_serializable(package='Vision')
class RetinaNetModel(tf_keras.Model):
  """The RetinaNet model class."""

  def __init__(self,
               backbone: tf_keras.Model,
               decoder: tf_keras.Model,
               head: tf_keras.layers.Layer,
               detection_generator: tf_keras.layers.Layer,
               min_level: Optional[int] = None,
               max_level: Optional[int] = None,
               num_scales: Optional[int] = None,
               aspect_ratios: Optional[List[float]] = None,
               anchor_size: Optional[float] = None,
               **kwargs):
    """Detection initialization function.

    Args:
      backbone: `tf_keras.Model` a backbone network.
      decoder: `tf_keras.Model` a decoder network.
      head: `RetinaNetHead`, the RetinaNet head.
      detection_generator: the detection generator.
      min_level: Minimum level in output feature maps.
      max_level: Maximum level in output feature maps.
      num_scales: A number representing intermediate scales added on each
        level. For instance, num_scales=2 adds one additional intermediate
        anchor scale [2^0, 2^0.5] on each level.
      aspect_ratios: A list representing the aspect ratios of the anchors
        added on each level. Each number indicates the ratio of width to
        height. For instance, aspect_ratios=[1.0, 2.0, 0.5] adds three
        anchors on each scale level.
      anchor_size: A number representing the scale of the base anchor size
        relative to the feature stride 2^level.
      **kwargs: keyword arguments to be passed.
    """
    super(RetinaNetModel, self).__init__(**kwargs)
    self._config_dict = {
        'backbone': backbone,
        'decoder': decoder,
        'head': head,
        'detection_generator': detection_generator,
        'min_level': min_level,
        'max_level': max_level,
        'num_scales': num_scales,
        'aspect_ratios': aspect_ratios,
        'anchor_size': anchor_size,
    }
    self._backbone = backbone
    self._decoder = decoder
    self._head = head
    self._detection_generator = detection_generator

  def call(self,
           images: Union[tf.Tensor, Sequence[tf.Tensor]],
           image_shape: Optional[tf.Tensor] = None,
           anchor_boxes: Optional[Mapping[str, tf.Tensor]] = None,
           output_intermediate_features: bool = False,
           training: Optional[bool] = None) -> Mapping[str, tf.Tensor]:
    """Forward pass of the RetinaNet model.

    Args:
      images: `Tensor` or a sequence of `Tensor`, the input batched images to
        the backbone network, whose shape(s) is [batch, height, width, 3]. If
        a sequence of `Tensor` is given, the anchors are generated based on
        the shape of the first image.
      image_shape: `Tensor`, the actual shape of the input images, whose shape
        is [batch, 2] where the last dimension is [height, width]. Note that
        this is the actual image shape excluding paddings. For example, images
        in the batch may be resized into different shapes before padding to
        the fixed size.
      anchor_boxes: a dict of tensors which includes multilevel anchors.
        - key: `str`, the level of the multilevel predictions.
        - values: `Tensor`, the anchor coordinates of a particular feature
          level, whose shape is
          [height_l, width_l, 4 * num_anchors_per_location].
      output_intermediate_features: `bool` indicating whether to return the
        intermediate feature maps generated by backbone and decoder.
      training: `bool`, indicating whether it is in training mode.

    Returns:
      scores: a dict of tensors which includes scores of the predictions.
        - key: `str`, the level of the multilevel predictions.
        - values: `Tensor`, the box scores predicted from a particular feature
          level, whose shape is
          [batch, height_l, width_l, num_classes * num_anchors_per_location].
      boxes: a dict of tensors which includes coordinates of the predictions.
        - key: `str`, the level of the multilevel predictions.
        - values: `Tensor`, the box coordinates predicted from a particular
          feature level, whose shape is
          [batch, height_l, width_l, 4 * num_anchors_per_location].
      attributes: a dict of (attribute_name, attribute_predictions). Each
        attribute prediction is a dict that includes:
        - key: `str`, the level of the multilevel predictions.
        - values: `Tensor`, the attribute predictions from a particular
          feature level, whose shape is
          [batch, height_l, width_l, att_size * num_anchors_per_location].
    """
    outputs = {}

    # Feature extraction.
    features = self.backbone(images)
    if output_intermediate_features:
      outputs.update(
          {'backbone_{}'.format(k): v for k, v in features.items()})
    if self.decoder:
      features = self.decoder(features)
      if output_intermediate_features:
        outputs.update(
            {'decoder_{}'.format(k): v for k, v in features.items()})

    # Dense prediction. `raw_attributes` can be empty.
    raw_scores, raw_boxes, raw_attributes = self.head(features)

    if training:
      outputs.update({
          'cls_outputs': raw_scores,
          'box_outputs': raw_boxes,
      })
      if raw_attributes:
        outputs.update({'attribute_outputs': raw_attributes})
      return outputs
    else:
      # Generate anchor boxes for this batch if not provided.
      if anchor_boxes is None:
        if isinstance(images, Sequence):
          primary_images = images[0]
        elif isinstance(images, tf.Tensor):
          primary_images = images
        else:
          raise ValueError(
              'Input should be a tf.Tensor or a sequence of tf.Tensor, not {}.'
              .format(type(images)))

        _, image_height, image_width, _ = primary_images.get_shape().as_list()
        anchor_boxes = anchor.Anchor(
            min_level=self._config_dict['min_level'],
            max_level=self._config_dict['max_level'],
            num_scales=self._config_dict['num_scales'],
            aspect_ratios=self._config_dict['aspect_ratios'],
            anchor_size=self._config_dict['anchor_size'],
            image_size=(image_height, image_width)).multilevel_boxes
        # Tile the per-level anchors with a leading batch dimension so they
        # line up with the batched box predictions.
        for l in anchor_boxes:
          anchor_boxes[l] = tf.tile(
              tf.expand_dims(anchor_boxes[l], axis=0),
              [tf.shape(primary_images)[0], 1, 1, 1])

      # Post-processing.
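      # The detection generator decodes the raw box regression outputs with
      # respect to the anchor boxes and, depending on its `apply_nms` config,
      # runs non-max suppression to produce the final detections; any
      # `raw_attributes` are forwarded so attribute predictions can be
      # gathered for the kept boxes.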
      final_results = self.detection_generator(raw_boxes, raw_scores,
                                               anchor_boxes, image_shape,
                                               raw_attributes)
      outputs.update({
          'cls_outputs': raw_scores,
          'box_outputs': raw_boxes,
      })

      def _update_decoded_results():
        outputs.update({
            'decoded_boxes': final_results['decoded_boxes'],
            'decoded_box_scores': final_results['decoded_box_scores'],
        })
        if final_results.get('decoded_box_attributes') is not None:
          outputs['decoded_box_attributes'] = final_results[
              'decoded_box_attributes'
          ]

      if self.detection_generator.get_config()['apply_nms']:
        outputs.update({
            'detection_boxes': final_results['detection_boxes'],
            'detection_scores': final_results['detection_scores'],
            'detection_classes': final_results['detection_classes'],
            'num_detections': final_results['num_detections'],
        })

        # Users can choose to include the decoded results (boxes before NMS)
        # in the output tensor dict even if `apply_nms` is set to `True`.
        if self.detection_generator.get_config()['return_decoded']:
          _update_decoded_results()
      else:
        _update_decoded_results()

      if raw_attributes:
        outputs.update({
            'attribute_outputs': raw_attributes,
            'detection_attributes': final_results['detection_attributes'],
        })
      return outputs

  @property
  def checkpoint_items(
      self) -> Mapping[str, Union[tf_keras.Model, tf_keras.layers.Layer]]:
    """Returns a dictionary of items to be additionally checkpointed."""
    items = dict(backbone=self.backbone, head=self.head)
    if self.decoder is not None:
      items.update(decoder=self.decoder)
    return items

  @property
  def backbone(self) -> tf_keras.Model:
    return self._backbone

  @property
  def decoder(self) -> tf_keras.Model:
    return self._decoder

  @property
  def head(self) -> tf_keras.layers.Layer:
    return self._head

  @property
  def detection_generator(self) -> tf_keras.layers.Layer:
    return self._detection_generator

  def get_config(self) -> Mapping[str, Any]:
    return self._config_dict

  @classmethod
  def from_config(cls, config):
    return cls(**config)
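

# Example (an illustrative sketch, not part of the library): generating the
# same multilevel anchors that `call` builds internally when `anchor_boxes` is
# not passed. The parameter values below are hypothetical; only the
# `anchor.Anchor` usage mirrors the code above.
#
#   from official.vision.ops import anchor
#
#   multilevel_boxes = anchor.Anchor(
#       min_level=3,
#       max_level=7,
#       num_scales=3,
#       aspect_ratios=[0.5, 1.0, 2.0],
#       anchor_size=4.0,
#       image_size=(640, 640)).multilevel_boxes
#
#   # Each level maps to a [height_l, width_l, 4 * num_anchors_per_location]
#   # tensor; `call` tiles these with a leading batch dimension before
#   # passing them to the detection generator.
#   for level, boxes in multilevel_boxes.items():
#     print(level, boxes.shape)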