Spaces:
Runtime error
Runtime error
# Copyright 2023 The TensorFlow Authors. All Rights Reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
"""RetinaNet.""" | |
from typing import Any, Mapping, List, Optional, Union, Sequence | |
# Import libraries | |
import tensorflow as tf, tf_keras | |
from official.vision.ops import anchor | |
class RetinaNetModel(tf_keras.Model):
  """The RetinaNet model class.

  Wires a backbone, an optional decoder, a dense prediction head and a
  detection generator together. In training mode the raw head outputs are
  returned; otherwise anchors are generated (unless supplied) and the raw
  predictions are post-processed by the detection generator.
  """

  def __init__(self,
               backbone: tf_keras.Model,
               decoder: Optional[tf_keras.Model],
               head: tf_keras.layers.Layer,
               detection_generator: tf_keras.layers.Layer,
               min_level: Optional[int] = None,
               max_level: Optional[int] = None,
               num_scales: Optional[int] = None,
               aspect_ratios: Optional[List[float]] = None,
               anchor_size: Optional[float] = None,
               **kwargs):
    """Detection initialization function.

    The anchor parameters (`min_level` ... `anchor_size`) are only used at
    inference time, when `call` is invoked without `anchor_boxes`.

    Args:
      backbone: `tf_keras.Model` a backbone network.
      decoder: `tf_keras.Model` a decoder network, or `None` to feed the
        backbone features to the head directly.
      head: `RetinaNetHead`, the RetinaNet head.
      detection_generator: the detection generator.
      min_level: Minimum level in output feature maps.
      max_level: Maximum level in output feature maps.
      num_scales: A number representing intermediate scales added
        on each level. For instance, num_scales=2 adds one additional
        intermediate anchor scale [2^0, 2^0.5] on each level.
      aspect_ratios: A list representing the aspect ratio
        anchors added on each level. The number indicates the ratio of width to
        height. For instance, aspect_ratios=[1.0, 2.0, 0.5] adds three anchors
        on each scale level.
      anchor_size: A number representing the scale of size of the base
        anchor to the feature stride 2^level.
      **kwargs: keyword arguments to be passed.
    """
    super(RetinaNetModel, self).__init__(**kwargs)
    # Constructor arguments kept verbatim so `from_config(get_config())` can
    # rebuild the model.
    self._config_dict = {
        'backbone': backbone,
        'decoder': decoder,
        'head': head,
        'detection_generator': detection_generator,
        'min_level': min_level,
        'max_level': max_level,
        'num_scales': num_scales,
        'aspect_ratios': aspect_ratios,
        'anchor_size': anchor_size,
    }
    self._backbone = backbone
    self._decoder = decoder
    self._head = head
    self._detection_generator = detection_generator

  def call(self,
           images: Union[tf.Tensor, Sequence[tf.Tensor]],
           image_shape: Optional[tf.Tensor] = None,
           anchor_boxes: Optional[Mapping[str, tf.Tensor]] = None,
           output_intermediate_features: bool = False,
           training: Optional[bool] = None) -> Mapping[str, tf.Tensor]:
    """Forward pass of the RetinaNet model.

    Args:
      images: `Tensor` or a sequence of `Tensor`, the input batched images to
        the backbone network, whose shape(s) is [batch, height, width, 3]. If it
        is a sequence of `Tensor`, we will assume the anchors are generated
        based on the shape of the first image(s).
      image_shape: `Tensor`, the actual shape of the input images, whose shape
        is [batch, 2] where the last dimension is [height, width]. Note that
        this is the actual image shape excluding paddings. For example, images
        in the batch may be resized into different shapes before padding to the
        fixed size. Only forwarded to the detection generator.
      anchor_boxes: a dict of tensors which includes multilevel anchors. Only
        used when not training; generated from `images` when `None`.
        - key: `str`, the level of the multilevel predictions.
        - values: `Tensor`, the anchor coordinates of a particular feature
          level. Internally generated anchors are tiled to a leading batch
          dimension, so supplied anchors are presumably expected in the same
          batched layout (TODO confirm against the detection generator).
      output_intermediate_features: `bool` indicating whether to return the
        intermediate feature maps generated by backbone and decoder.
      training: `bool`, indicating whether it is in training mode.

    Returns:
      A dict of tensors. Always contains:
        cls_outputs: a dict, level -> box scores of shape
          [batch, height_l, width_l, num_classes * num_anchors_per_location].
        box_outputs: a dict, level -> box coordinates of shape
          [batch, height_l, width_l, 4 * num_anchors_per_location].
      attribute_outputs: present only when the head produced attributes; a
        dict of (attribute_name, attribute_predictions), each a dict
        level -> [batch, height_l, width_l, att_size * num_anchors_per_location].
      When not training, additionally:
        detection_boxes, detection_scores, detection_classes, num_detections:
          when the detection generator's config has `apply_nms`.
        decoded_boxes, decoded_box_scores (and decoded_box_attributes when
          available): when `apply_nms` is off, or when `return_decoded` is on.
        detection_attributes: when attributes were produced and the detection
          generator returned them.
      With `output_intermediate_features`, also `backbone_<key>` and
      `decoder_<key>` feature maps.

    Raises:
      ValueError: if anchors must be generated and `images` is neither a
        `tf.Tensor` nor a sequence of them.
    """
    outputs = {}

    # Feature extraction.
    features = self.backbone(images)
    if output_intermediate_features:
      outputs.update(
          {'backbone_{}'.format(k): v for k, v in features.items()})
    if self.decoder is not None:
      features = self.decoder(features)
    if output_intermediate_features:
      # Without a decoder these are the backbone features again, re-exported
      # under `decoder_` keys.
      outputs.update(
          {'decoder_{}'.format(k): v for k, v in features.items()})

    # Dense prediction. `raw_attributes` can be empty.
    raw_scores, raw_boxes, raw_attributes = self.head(features)

    if training:
      outputs.update({
          'cls_outputs': raw_scores,
          'box_outputs': raw_boxes,
      })
      if raw_attributes:
        outputs.update({'attribute_outputs': raw_attributes})
      return outputs

    # Generate anchor boxes for this batch if not provided.
    if anchor_boxes is None:
      anchor_boxes = self._generate_anchor_boxes(images)

    # Post-processing.
    final_results = self.detection_generator(raw_boxes, raw_scores,
                                             anchor_boxes, image_shape,
                                             raw_attributes)
    outputs.update({
        'cls_outputs': raw_scores,
        'box_outputs': raw_boxes,
    })

    def _update_decoded_results():
      """Copies the pre-NMS decoded boxes/scores (and attributes) to outputs."""
      outputs.update({
          'decoded_boxes': final_results['decoded_boxes'],
          'decoded_box_scores': final_results['decoded_box_scores'],
      })
      if final_results.get('decoded_box_attributes') is not None:
        outputs['decoded_box_attributes'] = final_results[
            'decoded_box_attributes'
        ]

    generator_config = self.detection_generator.get_config()
    if generator_config['apply_nms']:
      outputs.update({
          'detection_boxes': final_results['detection_boxes'],
          'detection_scores': final_results['detection_scores'],
          'detection_classes': final_results['detection_classes'],
          'num_detections': final_results['num_detections'],
      })
      # Users can choose to include the decoded results (boxes before NMS) in
      # the output tensor dict even if `apply_nms` is set to `True`.
      if generator_config['return_decoded']:
        _update_decoded_results()
    else:
      _update_decoded_results()

    if raw_attributes:
      outputs['attribute_outputs'] = raw_attributes
      # NOTE(review): the non-NMS path above reads `decoded_*` keys, so the
      # generator may not return `detection_attributes` there; guard instead
      # of raising KeyError.
      if 'detection_attributes' in final_results:
        outputs['detection_attributes'] = final_results['detection_attributes']
    return outputs

  def _generate_anchor_boxes(
      self, images: Union[tf.Tensor, Sequence[tf.Tensor]]
  ) -> Mapping[str, tf.Tensor]:
    """Builds multilevel anchors for `images`, tiled once per batch element.

    Args:
      images: `Tensor` or a sequence of `Tensor`; the first element of a
        sequence determines the image size and batch size.

    Returns:
      A dict, level -> anchors with a leading batch dimension.

    Raises:
      ValueError: if `images` is neither a `tf.Tensor` nor a sequence.
    """
    if isinstance(images, Sequence):
      primary_images = images[0]
    elif isinstance(images, tf.Tensor):
      primary_images = images
    else:
      raise ValueError(
          'Input should be a tf.Tensor or a sequence of tf.Tensor, not {}.'
          .format(type(images)))

    # NOTE(review): uses the *static* height/width; if unknown at trace time
    # they will be None -- confirm `anchor.Anchor` supports that.
    _, image_height, image_width, _ = primary_images.get_shape().as_list()
    anchor_boxes = anchor.Anchor(
        min_level=self._config_dict['min_level'],
        max_level=self._config_dict['max_level'],
        num_scales=self._config_dict['num_scales'],
        aspect_ratios=self._config_dict['aspect_ratios'],
        anchor_size=self._config_dict['anchor_size'],
        image_size=(image_height, image_width)).multilevel_boxes

    # The same anchors apply to every image: add a batch axis and replicate.
    batch_size = tf.shape(primary_images)[0]
    for level in anchor_boxes:
      anchor_boxes[level] = tf.tile(
          tf.expand_dims(anchor_boxes[level], axis=0), [batch_size, 1, 1, 1])
    return anchor_boxes

  @property
  def checkpoint_items(
      self) -> Mapping[str, Union[tf_keras.Model, tf_keras.layers.Layer]]:
    """Returns a dictionary of items to be additionally checkpointed.

    Always contains `backbone` and `head`; `decoder` is included only when the
    model was built with one.
    """
    items = dict(backbone=self.backbone, head=self.head)
    if self.decoder is not None:
      items.update(decoder=self.decoder)
    return items

  @property
  def backbone(self) -> tf_keras.Model:
    """The backbone network given at construction."""
    return self._backbone

  @property
  def decoder(self) -> Optional[tf_keras.Model]:
    """The decoder network given at construction, possibly `None`."""
    return self._decoder

  @property
  def head(self) -> tf_keras.layers.Layer:
    """The dense prediction head given at construction."""
    return self._head

  @property
  def detection_generator(self) -> tf_keras.layers.Layer:
    """The post-processing detection generator given at construction."""
    return self._detection_generator

  def get_config(self) -> Mapping[str, Any]:
    """Returns the constructor arguments (including the sub-layers)."""
    return self._config_dict

  @classmethod
  def from_config(cls, config):
    """Re-creates the model from a dict produced by `get_config`."""
    return cls(**config)