# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Detection input and model functions for serving/inference."""

import math
from typing import Mapping, Tuple

from absl import logging
import tensorflow as tf, tf_keras

from official.vision import configs
from official.vision.modeling import factory
from official.vision.ops import anchor
from official.vision.ops import box_ops
from official.vision.ops import preprocess_ops
from official.vision.serving import export_base


class DetectionModule(export_base.ExportModule):
  """Detection Module."""

  @property
  def _padded_size(self):
    """Returns the spatial size the model input is padded to.

    When `task.train_data.parser.pad` is set, the input image size is passed
    through `preprocess_ops.compute_padded_size` with stride
    `2**task.model.max_level` (presumably rounding each side up to a multiple
    of that stride); otherwise the input image size is returned unchanged.
    """
    if self.params.task.train_data.parser.pad:
      return preprocess_ops.compute_padded_size(
          self._input_image_size, 2**self.params.task.model.max_level
      )
    else:
      return self._input_image_size

  def _build_model(self):
    """Builds the detection model described by `self.params.task.model`.

    If the export batch size is dynamic (`None`) and the configured NMS
    version is not one of `batched`, `v2` or `v3`, the config is mutated in
    place to use `batched` NMS before the model is built.

    Returns:
      A Mask R-CNN model for `configs.maskrcnn.MaskRCNN` configs, or a
      RetinaNet model for `configs.retinanet.RetinaNet` configs.

    Raises:
      ValueError: If the model config is of any other type.
    """
    nms_versions_supporting_dynamic_batch_size = {'batched', 'v2', 'v3'}
    nms_version = self.params.task.model.detection_generator.nms_version
    if (self._batch_size is None and
        nms_version not in nms_versions_supporting_dynamic_batch_size):
      logging.info('nms_version is set to `batched` because `%s` '
                   'does not support with dynamic batch size.', nms_version)
      self.params.task.model.detection_generator.nms_version = 'batched'

    input_specs = tf_keras.layers.InputSpec(shape=[
        self._batch_size, *self._padded_size, 3])

    if isinstance(self.params.task.model, configs.maskrcnn.MaskRCNN):
      model = factory.build_maskrcnn(
          input_specs=input_specs, model_config=self.params.task.model)
    elif isinstance(self.params.task.model, configs.retinanet.RetinaNet):
      model = factory.build_retinanet(
          input_specs=input_specs, model_config=self.params.task.model)
    else:
      raise ValueError('Detection module not implemented for {} model.'.format(
          type(self.params.task.model)))

    return model

  def _build_anchor_boxes(self):
    """Builds and returns anchor boxes for the padded input size.

    Returns:
      The output of the generator from `anchor.build_anchor_generator`
      evaluated at `image_size=self._padded_size`. `preprocess` describes it
      as a dict keyed by pyramid level (as a string).
    """
    model_params = self.params.task.model
    input_anchor = anchor.build_anchor_generator(
        min_level=model_params.min_level,
        max_level=model_params.max_level,
        num_scales=model_params.anchor.num_scales,
        aspect_ratios=model_params.anchor.aspect_ratios,
        anchor_size=model_params.anchor.anchor_size)
    return input_anchor(image_size=self._padded_size)

  def _build_inputs(self, image):
    """Builds detection model inputs for a single image for serving.

    Args:
      image: A float image tensor for one example; `preprocess` maps this
        function over the batch dimension.

    Returns:
      A tuple `(image, anchor_boxes, image_info)`:
        image: The normalized image, resized to the input image size and
          padded to `self._padded_size`.
        anchor_boxes: The anchors from `_build_anchor_boxes`.
        image_info: The resize metadata produced by
          `preprocess_ops.resize_and_crop_image`.
    """
    # Normalizes image with mean and std pixel values.
    image = preprocess_ops.normalize_image(
        image, offset=preprocess_ops.MEAN_RGB, scale=preprocess_ops.STDDEV_RGB)

    # Scale augmentation is disabled (min == max == 1.0) for serving.
    image, image_info = preprocess_ops.resize_and_crop_image(
        image,
        self._input_image_size,
        padded_size=self._padded_size,
        aug_scale_min=1.0,
        aug_scale_max=1.0,
        keep_aspect_ratio=self.params.task.train_data.parser.keep_aspect_ratio,
    )

    anchor_boxes = self._build_anchor_boxes()

    return image, anchor_boxes, image_info

  def _normalize_coordinates(self, detections_dict, dict_keys, image_info):
    """Normalizes detection coordinates between 0 and 1.

    Each box tensor is divided by the resize scale in `image_info[:, 2]`, then
    normalized by the original image size in `image_info[:, 0]` and clipped
    to [0, 1]. The offsets in `image_info[:, 3]` are not applied.

    Args:
      detections_dict: Dictionary containing the output of the model
        prediction.
      dict_keys: Key names corresponding to the tensors of the output
        dictionary that we want to update. Keys missing from
        `detections_dict` are skipped.
      image_info: Tensor containing the details of the image resizing.

    Returns:
      detections_dict: Updated detection dictionary (updated in place and
        returned).
    """
    for key in dict_keys:
      if key not in detections_dict:
        continue
      # Tile the [y_scale, x_scale] pair to [y, x, y, x] to match the 4 box
      # coordinates.
      detection_boxes = detections_dict[key] / tf.tile(
          image_info[:, 2:3, :], [1, 1, 2]
      )
      detections_dict[key] = box_ops.normalize_boxes(
          detection_boxes, image_info[:, 0:1, :]
      )
      detections_dict[key] = tf.clip_by_value(detections_dict[key], 0.0, 1.0)

    return detections_dict

  def preprocess(
      self, images: tf.Tensor
  ) -> Tuple[tf.Tensor, Mapping[str, tf.Tensor], tf.Tensor]:
    """Preprocesses inputs to be suitable for the model.

    Args:
      images: The images tensor, batched along the first dimension.

    Returns:
      images: The images cast to float32, normalized, resized and padded to
        `self._padded_size`.
      anchor_boxes: Dict mapping anchor levels (as strings) to anchor boxes.
      image_info: Tensor containing the details of the image resizing, one
        [4, 2] entry per image.
    """
    model_params = self.params.task.model
    # `_padded_size` is a property; evaluate it once.
    padded_size = self._padded_size

    with tf.device('cpu:0'):
      images = tf.cast(images, dtype=tf.float32)

      # Tensor Specs for map_fn outputs (images, anchor_boxes, and image_info).
      # Unpacking (rather than `padded_size + [3]`) also accepts a tuple size.
      images_spec = tf.TensorSpec(shape=[*padded_size, 3], dtype=tf.float32)

      # The trailing factor of 4 is presumably the 4 box coordinates per
      # (scale, aspect ratio) anchor.
      num_anchors = model_params.anchor.num_scales * len(
          model_params.anchor.aspect_ratios) * 4
      anchor_specs = {
          str(level): tf.TensorSpec(
              shape=[
                  math.ceil(padded_size[0] / 2**level),
                  math.ceil(padded_size[1] / 2**level),
                  num_anchors,
              ],
              dtype=tf.float32)
          for level in range(model_params.min_level, model_params.max_level + 1)
      }

      image_info_spec = tf.TensorSpec(shape=[4, 2], dtype=tf.float32)

      images, anchor_boxes, image_info = tf.nest.map_structure(
          tf.identity,
          tf.map_fn(
              self._build_inputs,
              elems=images,
              fn_output_signature=(images_spec, anchor_specs, image_info_spec),
              parallel_iterations=32))

      return images, anchor_boxes, image_info

  def serve(self, images: tf.Tensor):
    """Casts image to float and runs inference.

    Args:
      images: uint8 Tensor of shape [batch_size, None, None, 3]

    Returns:
      Dict of output tensors. When the detection generator applies NMS, it
      contains `detection_boxes`, `detection_scores`, `detection_classes`,
      `num_detections` and, if produced, `detection_outer_boxes`. Otherwise
      it contains `decoded_boxes` and `decoded_box_scores`. The following are
      added when applicable: `detection_masks`, RetinaNet `backbone_*` /
      `decoder_*` intermediate features, and `image_info` (unless
      `nms_version` is `tflite`).
    """
    # Skip image preprocessing when input_type is tflite so it is compatible
    # with TFLite quantization.
    if self._input_type != 'tflite':
      images, anchor_boxes, image_info = self.preprocess(images)
    else:
      with tf.device('cpu:0'):
        anchor_boxes = self._build_anchor_boxes()
        # image_info is in the format of
        # [[original_height, original_width], [desired_height, desired_width],
        #  [y_scale, x_scale], [y_offset, x_offset]].
        # Here a single [1, 4, 2] entry is built for the whole batch because
        # the input image is supposed to be preprocessed already.
        # NOTE(review): this relies on downstream ops broadcasting it against
        # the batch dimension.
        image_info = tf.convert_to_tensor([[
            self._input_image_size, self._input_image_size, [1.0, 1.0], [0, 0]
        ]], dtype=tf.float32)
    input_image_shape = image_info[:, 1, :]

    is_retinanet = isinstance(
        self.params.task.model, configs.retinanet.RetinaNet)

    # To overcome keras.Model extra limitation to save a model with layers that
    # have multiple inputs, we use `model.call` here to trigger the forward
    # path. Note that, this disables some keras magics happens in `__call__`.
    model_call_kwargs = {
        'images': images,
        'image_shape': input_image_shape,
        'anchor_boxes': anchor_boxes,
        'training': False,
    }
    if is_retinanet:
      model_call_kwargs['output_intermediate_features'] = (
          self.params.task.export_config.output_intermediate_features
      )
    detections = self.model.call(**model_call_kwargs)

    if self.params.task.model.detection_generator.apply_nms:
      # For RetinaNet model, apply export_config.
      # TODO(huizhongc): Add export_config to fasterrcnn and maskrcnn as needed.
      if is_retinanet:
        export_config = self.params.task.export_config
        # Normalize detection box coordinates to [0, 1].
        if export_config.output_normalized_coordinates:
          keys = ['detection_boxes', 'detection_outer_boxes']
          detections = self._normalize_coordinates(detections, keys, image_info)

        # Cast num_detections and detection_classes to float. This allows the
        # model inference to work on chain (go/chain) as chain requires floating
        # point outputs.
        if export_config.cast_num_detections_to_float:
          detections['num_detections'] = tf.cast(
              detections['num_detections'], dtype=tf.float32)
        if export_config.cast_detection_classes_to_float:
          detections['detection_classes'] = tf.cast(
              detections['detection_classes'], dtype=tf.float32)

      final_outputs = {
          'detection_boxes': detections['detection_boxes'],
          'detection_scores': detections['detection_scores'],
          'detection_classes': detections['detection_classes'],
          'num_detections': detections['num_detections'],
      }
      if 'detection_outer_boxes' in detections:
        final_outputs['detection_outer_boxes'] = (
            detections['detection_outer_boxes'])
    else:
      # For RetinaNet model, apply export_config.
      if is_retinanet:
        export_config = self.params.task.export_config
        # Normalize detection box coordinates to [0, 1].
        if export_config.output_normalized_coordinates:
          keys = ['decoded_boxes']
          detections = self._normalize_coordinates(detections, keys, image_info)
      final_outputs = {
          'decoded_boxes': detections['decoded_boxes'],
          'decoded_box_scores': detections['decoded_box_scores'],
      }

    if 'detection_masks' in detections:
      final_outputs['detection_masks'] = detections['detection_masks']
    if (
        is_retinanet
        and self.params.task.export_config.output_intermediate_features
    ):
      final_outputs.update(
          {
              k: v
              for k, v in detections.items()
              if k.startswith(('backbone_', 'decoder_'))
          }
      )

    if self.params.task.model.detection_generator.nms_version != 'tflite':
      final_outputs.update({'image_info': image_info})

    return final_outputs