deanna-emery's picture
updates
93528c6
raw
history blame
10.7 kB
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Detection input and model functions for serving/inference."""
import math
from typing import Mapping, Tuple
from absl import logging
import tensorflow as tf, tf_keras
from official.vision import configs
from official.vision.modeling import factory
from official.vision.ops import anchor
from official.vision.ops import box_ops
from official.vision.ops import preprocess_ops
from official.vision.serving import export_base
class DetectionModule(export_base.ExportModule):
  """Detection Module.

  Exports a Mask R-CNN or RetinaNet detection model (selected by the type of
  `params.task.model`) for serving. `serve` takes a batch of images and
  returns a dictionary of detection output tensors.
  """

  @property
  def _padded_size(self):
    """Returns the spatial size the model input is padded to.

    When `params.task.train_data.parser.pad` is set, the input image size is
    passed through `preprocess_ops.compute_padded_size` with a stride of
    `2**max_level` (presumably rounding each dimension up to a multiple of
    that stride); otherwise `self._input_image_size` is returned unchanged.
    """
    if self.params.task.train_data.parser.pad:
      return preprocess_ops.compute_padded_size(
          self._input_image_size, 2**self.params.task.model.max_level
      )
    return self._input_image_size

  def _build_model(self):
    """Builds the detection model described by `params.task.model`.

    If the export uses a dynamic batch size (`self._batch_size is None`) and
    the configured NMS version is not one of 'batched', 'v2' or 'v3', the
    config is mutated in place to use 'batched' NMS instead.

    Returns:
      The model built by `factory.build_maskrcnn` or
      `factory.build_retinanet`.

    Raises:
      ValueError: If `params.task.model` is neither a MaskRCNN nor a RetinaNet
        config.
    """
    nms_versions_supporting_dynamic_batch_size = {'batched', 'v2', 'v3'}
    nms_version = self.params.task.model.detection_generator.nms_version
    if (self._batch_size is None and
        nms_version not in nms_versions_supporting_dynamic_batch_size):
      logging.info('nms_version is set to `batched` because `%s` '
                   'does not support with dynamic batch size.', nms_version)
      self.params.task.model.detection_generator.nms_version = 'batched'

    input_specs = tf_keras.layers.InputSpec(shape=[
        self._batch_size, *self._padded_size, 3])

    if isinstance(self.params.task.model, configs.maskrcnn.MaskRCNN):
      model = factory.build_maskrcnn(
          input_specs=input_specs, model_config=self.params.task.model)
    elif isinstance(self.params.task.model, configs.retinanet.RetinaNet):
      model = factory.build_retinanet(
          input_specs=input_specs, model_config=self.params.task.model)
    else:
      raise ValueError('Detection module not implemented for {} model.'.format(
          type(self.params.task.model)))

    return model

  def _build_anchor_boxes(self):
    """Builds and returns anchor boxes for the padded input size.

    Returns:
      The result of the anchor generator built by
      `anchor.build_anchor_generator` from the model's level range and anchor
      config, evaluated at `self._padded_size`. `preprocess` treats it as a
      mapping from level (as a string) to a per-level anchor tensor.
    """
    model_params = self.params.task.model
    input_anchor = anchor.build_anchor_generator(
        min_level=model_params.min_level,
        max_level=model_params.max_level,
        num_scales=model_params.anchor.num_scales,
        aspect_ratios=model_params.anchor.aspect_ratios,
        anchor_size=model_params.anchor.anchor_size)
    return input_anchor(image_size=self._padded_size)

  def _build_inputs(self, image):
    """Builds detection model inputs for a single image.

    Normalizes the image with `preprocess_ops.MEAN_RGB` / `STDDEV_RGB`, then
    resizes it to `self._input_image_size` and pads it to `self._padded_size`
    with no scale augmentation (min and max scale are both 1.0).

    Args:
      image: One float image from the batch mapped over by `preprocess`
        (presumably [height, width, 3]; TODO confirm).

    Returns:
      A tuple `(image, anchor_boxes, image_info)`.
    """
    # Normalizes image with mean and std pixel values.
    image = preprocess_ops.normalize_image(
        image, offset=preprocess_ops.MEAN_RGB, scale=preprocess_ops.STDDEV_RGB)

    image, image_info = preprocess_ops.resize_and_crop_image(
        image,
        self._input_image_size,
        padded_size=self._padded_size,
        aug_scale_min=1.0,
        aug_scale_max=1.0,
        keep_aspect_ratio=self.params.task.train_data.parser.keep_aspect_ratio,
    )
    anchor_boxes = self._build_anchor_boxes()

    return image, anchor_boxes, image_info

  def _normalize_coordinates(self, detections_dict, dict_keys, image_info):
    """Normalizes detection coordinates between 0 and 1.

    Boxes are first divided by the resize scale stored in row 2 of
    `image_info`, which maps them back to original-image pixel coordinates.
    They are then normalized by the original image size in row 0 of
    `image_info` and clipped to [0, 1].

    Args:
      detections_dict: Dictionary containing the output of the model
        prediction. It is updated in place.
      dict_keys: Key names corresponding to the tensors of the output
        dictionary that we want to update. Keys missing from
        `detections_dict` are skipped.
      image_info: Tensor containing the details of the image resizing, of
        shape [batch_size, 4, 2] (see the layout described in `serve`).

    Returns:
      detections_dict: Updated detection dictionary (the same object that was
        passed in).
    """
    # [y_scale, x_scale] tiled to [y_scale, x_scale, y_scale, x_scale] so it
    # lines up with four-coordinate boxes; invariant across keys.
    box_scale = tf.tile(image_info[:, 2:3, :], [1, 1, 2])
    original_size = image_info[:, 0:1, :]
    for key in dict_keys:
      if key not in detections_dict:
        continue
      detection_boxes = detections_dict[key] / box_scale
      detections_dict[key] = box_ops.normalize_boxes(
          detection_boxes, original_size
      )
      detections_dict[key] = tf.clip_by_value(detections_dict[key], 0.0, 1.0)

    return detections_dict

  def preprocess(
      self, images: tf.Tensor
  ) -> Tuple[tf.Tensor, Mapping[str, tf.Tensor], tf.Tensor]:
    """Preprocesses inputs to be suitable for the model.

    On CPU, this casts the batch to float32 and then runs `_build_inputs` on
    each image with `tf.map_fn`.

    Args:
      images: The images tensor.

    Returns:
      images: The normalized, resized and padded float32 images, of shape
        [batch_size, padded_height, padded_width, 3].
      anchor_boxes: Dict mapping anchor levels (as strings) to anchor boxes.
      image_info: Tensor containing the details of the image resizing, of
        shape [batch_size, 4, 2].
    """
    model_params = self.params.task.model
    with tf.device('cpu:0'):
      images = tf.cast(images, dtype=tf.float32)

      # Tensor Specs for map_fn outputs (images, anchor_boxes, and image_info).
      # Unpacking (rather than `+ [3]`) works whether the size is a list or a
      # tuple, matching how `_build_model` builds its InputSpec.
      images_spec = tf.TensorSpec(shape=[*self._padded_size, 3],
                                  dtype=tf.float32)

      # The trailing factor of 4 is presumably the box coordinates per anchor.
      num_anchors = model_params.anchor.num_scales * len(
          model_params.anchor.aspect_ratios) * 4
      anchor_shapes = []
      for level in range(model_params.min_level, model_params.max_level + 1):
        anchor_level_spec = tf.TensorSpec(
            shape=[
                math.ceil(self._padded_size[0] / 2**level),
                math.ceil(self._padded_size[1] / 2**level),
                num_anchors,
            ],
            dtype=tf.float32)
        anchor_shapes.append((str(level), anchor_level_spec))

      image_info_spec = tf.TensorSpec(shape=[4, 2], dtype=tf.float32)

      images, anchor_boxes, image_info = tf.nest.map_structure(
          tf.identity,
          tf.map_fn(
              self._build_inputs,
              elems=images,
              fn_output_signature=(images_spec, dict(anchor_shapes),
                                   image_info_spec),
              parallel_iterations=32))

      return images, anchor_boxes, image_info

  def serve(self, images: tf.Tensor):
    """Preprocesses images (unless exporting for TFLite) and runs inference.

    Args:
      images: uint8 Tensor of shape [batch_size, None, None, 3]. When
        `self._input_type` is 'tflite', the images are passed to the model
        unchanged and are expected to be preprocessed already.

    Returns:
      A dictionary of output tensors.
        - With NMS enabled (`detection_generator.apply_nms`), it contains
          'detection_boxes', 'detection_scores', 'detection_classes' and
          'num_detections'. It also contains 'detection_outer_boxes' when the
          model produces it.
        - Without NMS, it contains 'decoded_boxes' and 'decoded_box_scores'.
        - 'detection_masks' is added when the model produces it.
        - RetinaNet 'backbone_*' / 'decoder_*' features are added when
          `export_config.output_intermediate_features` is set.
        - 'image_info' is added unless the NMS version is 'tflite'.
    """
    model_config = self.params.task.model
    is_retinanet = isinstance(model_config, configs.retinanet.RetinaNet)

    # Skip image preprocessing when input_type is tflite so it is compatible
    # with TFLite quantization.
    if self._input_type != 'tflite':
      images, anchor_boxes, image_info = self.preprocess(images)
    else:
      with tf.device('cpu:0'):
        anchor_boxes = self._build_anchor_boxes()
        # image_info is a 3D tensor of shape [batch_size, 4, 2]. It is in the
        # format of [[original_height, original_width],
        # [desired_height, desired_width], [y_scale, x_scale],
        # [y_offset, x_offset]]. When input_type is tflite, input image is
        # supposed to be preprocessed already, so a single identity resize
        # (batch dimension of 1, unit scale, zero offset) is reported.
        image_info = tf.convert_to_tensor(
            [[
                self._input_image_size,
                self._input_image_size,
                [1.0, 1.0],
                [0, 0],
            ]],
            dtype=tf.float32,
        )
    input_image_shape = image_info[:, 1, :]

    # To overcome keras.Model extra limitation to save a model with layers that
    # have multiple inputs, we use `model.call` here to trigger the forward
    # path. Note that, this disables some keras magics happens in `__call__`.
    model_call_kwargs = {
        'images': images,
        'image_shape': input_image_shape,
        'anchor_boxes': anchor_boxes,
        'training': False,
    }
    if is_retinanet:
      model_call_kwargs['output_intermediate_features'] = (
          self.params.task.export_config.output_intermediate_features
      )
    detections = self.model.call(**model_call_kwargs)

    if model_config.detection_generator.apply_nms:
      # For RetinaNet model, apply export_config.
      # TODO(huizhongc): Add export_config to fasterrcnn and maskrcnn as needed.
      if is_retinanet:
        export_config = self.params.task.export_config
        # Normalize detection box coordinates to [0, 1].
        if export_config.output_normalized_coordinates:
          keys = ['detection_boxes', 'detection_outer_boxes']
          detections = self._normalize_coordinates(detections, keys, image_info)

        # Cast num_detections and detection_classes to float. This allows the
        # model inference to work on chain (go/chain) as chain requires floating
        # point outputs.
        if export_config.cast_num_detections_to_float:
          detections['num_detections'] = tf.cast(
              detections['num_detections'], dtype=tf.float32)
        if export_config.cast_detection_classes_to_float:
          detections['detection_classes'] = tf.cast(
              detections['detection_classes'], dtype=tf.float32)

      final_outputs = {
          'detection_boxes': detections['detection_boxes'],
          'detection_scores': detections['detection_scores'],
          'detection_classes': detections['detection_classes'],
          'num_detections': detections['num_detections']
      }
      if 'detection_outer_boxes' in detections:
        final_outputs['detection_outer_boxes'] = (
            detections['detection_outer_boxes'])
    else:
      # For RetinaNet model, apply export_config.
      if is_retinanet:
        export_config = self.params.task.export_config
        # Normalize detection box coordinates to [0, 1].
        if export_config.output_normalized_coordinates:
          keys = ['decoded_boxes']
          detections = self._normalize_coordinates(detections, keys, image_info)
      final_outputs = {
          'decoded_boxes': detections['decoded_boxes'],
          'decoded_box_scores': detections['decoded_box_scores']
      }

    if 'detection_masks' in detections:
      final_outputs['detection_masks'] = detections['detection_masks']
    if (
        is_retinanet
        and self.params.task.export_config.output_intermediate_features
    ):
      final_outputs.update(
          {
              k: v
              for k, v in detections.items()
              if k.startswith(('backbone_', 'decoder_'))
          }
      )

    if model_config.detection_generator.nms_version != 'tflite':
      final_outputs.update({'image_info': image_info})
    return final_outputs