# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""RetinaNet."""
from typing import Any, Mapping, List, Optional, Union, Sequence

# Import libraries
import tensorflow as tf, tf_keras

from official.vision.ops import anchor


@tf_keras.utils.register_keras_serializable(package='Vision')
class RetinaNetModel(tf_keras.Model):
"""The RetinaNet model class."""
def __init__(self,
backbone: tf_keras.Model,
decoder: tf_keras.Model,
head: tf_keras.layers.Layer,
detection_generator: tf_keras.layers.Layer,
min_level: Optional[int] = None,
max_level: Optional[int] = None,
num_scales: Optional[int] = None,
aspect_ratios: Optional[List[float]] = None,
anchor_size: Optional[float] = None,
**kwargs):
"""Detection initialization function.
Args:
backbone: `tf_keras.Model` a backbone network.
decoder: `tf_keras.Model` a decoder network.
head: `RetinaNetHead`, the RetinaNet head.
detection_generator: the detection generator.
min_level: Minimum level in output feature maps.
max_level: Maximum level in output feature maps.
      num_scales: A number representing intermediate scales added
        on each level. For instance, num_scales=2 adds one additional
        intermediate scale, yielding anchor scales [2^0, 2^0.5] on each level.
      aspect_ratios: A list representing the aspect ratios of the anchors
        added on each level. Each number indicates the ratio of width to
        height. For instance, aspect_ratios=[1.0, 2.0, 0.5] adds three anchors
        on each scale level, so every spatial location has
        num_scales * len(aspect_ratios) anchors per level.
      anchor_size: A number representing the size of the base anchor relative
        to the feature stride 2^level.
**kwargs: keyword arguments to be passed.
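
    Example (a minimal construction sketch, not from the original module;
    `build_backbone`, `build_decoder`, `build_head`, and `build_generator`
    are hypothetical builders standing in for any compatible `tf_keras`
    models/layers):

      model = RetinaNetModel(
          backbone=build_backbone(),
          decoder=build_decoder(),
          head=build_head(),
          detection_generator=build_generator(),
          min_level=3,
          max_level=7,
          num_scales=3,
          aspect_ratios=[1.0, 2.0, 0.5],
          anchor_size=4.0)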
"""
    super().__init__(**kwargs)
self._config_dict = {
'backbone': backbone,
'decoder': decoder,
'head': head,
'detection_generator': detection_generator,
'min_level': min_level,
'max_level': max_level,
'num_scales': num_scales,
'aspect_ratios': aspect_ratios,
'anchor_size': anchor_size,
}
self._backbone = backbone
self._decoder = decoder
self._head = head
    self._detection_generator = detection_generator

def call(self,
images: Union[tf.Tensor, Sequence[tf.Tensor]],
image_shape: Optional[tf.Tensor] = None,
anchor_boxes: Optional[Mapping[str, tf.Tensor]] = None,
output_intermediate_features: bool = False,
           training: Optional[bool] = None) -> Mapping[str, tf.Tensor]:
    """Forward pass of the RetinaNet model.

    Args:
      images: `Tensor` or a sequence of `Tensor`, the input batched images to
        the backbone network, whose shapes are [batch, height, width, 3]. If it
        is a sequence of `Tensor`, the anchors are generated based on the shape
        of the first image.
image_shape: `Tensor`, the actual shape of the input images, whose shape
is [batch, 2] where the last dimension is [height, width]. Note that
this is the actual image shape excluding paddings. For example, images
in the batch may be resized into different shapes before padding to the
fixed size.
      anchor_boxes: a dict of tensors which includes multilevel anchors.
        - key: `str`, the level of the multilevel predictions.
        - values: `Tensor`, the anchor coordinates of a particular feature
          level, whose shape is
          [height_l, width_l, 4 * num_anchors_per_location].
output_intermediate_features: `bool` indicating whether to return the
intermediate feature maps generated by backbone and decoder.
training: `bool`, indicating whether it is in training mode.

    Returns:
scores: a dict of tensors which includes scores of the predictions.
- key: `str`, the level of the multilevel predictions.
- values: `Tensor`, the box scores predicted from a particular feature
level, whose shape is
[batch, height_l, width_l, num_classes * num_anchors_per_location].
boxes: a dict of tensors which includes coordinates of the predictions.
- key: `str`, the level of the multilevel predictions.
- values: `Tensor`, the box coordinates predicted from a particular
feature level, whose shape is
[batch, height_l, width_l, 4 * num_anchors_per_location].
attributes: a dict of (attribute_name, attribute_predictions). Each
attribute prediction is a dict that includes:
- key: `str`, the level of the multilevel predictions.
- values: `Tensor`, the attribute predictions from a particular
feature level, whose shape is
[batch, height_l, width_l, att_size * num_anchors_per_location].
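
    Example (a sketch, not from the original module; assumes `model` is a
    built `RetinaNetModel` and `images` is a float `Tensor` of shape
    [batch, height, width, 3]):

      train_outputs = model(images, training=True)   # raw multilevel outputs
      eval_outputs = model(images, training=False)   # adds decoded detections
      boxes = eval_outputs['detection_boxes']        # when apply_nms is True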
"""
outputs = {}
# Feature extraction.
features = self.backbone(images)
if output_intermediate_features:
outputs.update(
{'backbone_{}'.format(k): v for k, v in features.items()})
if self.decoder:
features = self.decoder(features)
if output_intermediate_features:
outputs.update(
{'decoder_{}'.format(k): v for k, v in features.items()})
# Dense prediction. `raw_attributes` can be empty.
raw_scores, raw_boxes, raw_attributes = self.head(features)
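    # In training mode, return only the raw multilevel predictions; losses are
    # typically computed from these outputs outside the model.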
if training:
outputs.update({
'cls_outputs': raw_scores,
'box_outputs': raw_boxes,
})
if raw_attributes:
outputs.update({'attribute_outputs': raw_attributes})
return outputs
else:
# Generate anchor boxes for this batch if not provided.
if anchor_boxes is None:
if isinstance(images, Sequence):
primary_images = images[0]
elif isinstance(images, tf.Tensor):
primary_images = images
else:
raise ValueError(
'Input should be a tf.Tensor or a sequence of tf.Tensor, not {}.'
.format(type(images)))
_, image_height, image_width, _ = primary_images.get_shape().as_list()
anchor_boxes = anchor.Anchor(
min_level=self._config_dict['min_level'],
max_level=self._config_dict['max_level'],
num_scales=self._config_dict['num_scales'],
aspect_ratios=self._config_dict['aspect_ratios'],
anchor_size=self._config_dict['anchor_size'],
image_size=(image_height, image_width)).multilevel_boxes
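        # The generated anchors are per-image; tile them across the batch
        # dimension so every image in the batch shares the same anchor grid.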
for l in anchor_boxes:
anchor_boxes[l] = tf.tile(
tf.expand_dims(anchor_boxes[l], axis=0),
[tf.shape(primary_images)[0], 1, 1, 1])
# Post-processing.
final_results = self.detection_generator(raw_boxes, raw_scores,
anchor_boxes, image_shape,
raw_attributes)
outputs.update({
'cls_outputs': raw_scores,
'box_outputs': raw_boxes,
})
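
      # Helper to expose the pre-NMS decoded boxes/scores (and attributes, if
      # present) in the output dict; used when NMS is skipped and when
      # `return_decoded` requests them alongside the NMS results.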
def _update_decoded_results():
outputs.update({
'decoded_boxes': final_results['decoded_boxes'],
'decoded_box_scores': final_results['decoded_box_scores'],
})
if final_results.get('decoded_box_attributes') is not None:
outputs['decoded_box_attributes'] = final_results[
'decoded_box_attributes'
]
if self.detection_generator.get_config()['apply_nms']:
outputs.update({
'detection_boxes': final_results['detection_boxes'],
'detection_scores': final_results['detection_scores'],
'detection_classes': final_results['detection_classes'],
'num_detections': final_results['num_detections'],
})
# Users can choose to include the decoded results (boxes before NMS) in
# the output tensor dict even if `apply_nms` is set to `True`.
if self.detection_generator.get_config()['return_decoded']:
_update_decoded_results()
else:
_update_decoded_results()
if raw_attributes:
outputs.update({
'attribute_outputs': raw_attributes,
'detection_attributes': final_results['detection_attributes'],
})
    return outputs

@property
def checkpoint_items(
self) -> Mapping[str, Union[tf_keras.Model, tf_keras.layers.Layer]]:
"""Returns a dictionary of items to be additionally checkpointed."""
items = dict(backbone=self.backbone, head=self.head)
if self.decoder is not None:
items.update(decoder=self.decoder)
return items

  @property
  def backbone(self) -> tf_keras.Model:
    return self._backbone

  @property
  def decoder(self) -> tf_keras.Model:
    return self._decoder

  @property
  def head(self) -> tf_keras.layers.Layer:
    return self._head

  @property
  def detection_generator(self) -> tf_keras.layers.Layer:
    return self._detection_generator

  def get_config(self) -> Mapping[str, Any]:
    return self._config_dict

@classmethod
def from_config(cls, config):
return cls(**config)