# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Model definition for the RetinaNet Model."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
import tf_keras

from official.legacy.detection.dataloader import mode_keys
from official.legacy.detection.evaluation import factory as eval_factory
from official.legacy.detection.modeling import base_model
from official.legacy.detection.modeling import losses
from official.legacy.detection.modeling.architecture import factory
from official.legacy.detection.ops import postprocess_ops


def _check_required_fields(container, required_fields, container_name):
  """Raises ValueError if any of `required_fields` is absent from `container`.

  Args:
    container: a dict-like object supporting `in` and `.keys()`.
    required_fields: list of keys that must be present in `container`.
    container_name: human-readable name of `container` used in the error
      message (e.g. 'outputs' or 'labels').

  Raises:
    ValueError: on the first field in `required_fields` that is missing.
  """
  for field in required_fields:
    if field not in container:
      raise ValueError(
          '"%s" is missing in %s, required %s found %s' %
          (field, container_name, required_fields, container.keys()))


class RetinanetModel(base_model.Model):
  """RetinaNet model function.

  Wires together a backbone, a multilevel feature generator and a RetinaNet
  head (all produced by the architecture `factory`), the classification and
  box losses, and a multilevel detection generator used for post-processing.
  """

  def __init__(self, params):
    """Builds the sub-network generators, losses and the Keras input layer.

    Args:
      params: the model configuration. This constructor reads
        `params.architecture`, `params.retinanet_loss`, `params.postprocess`,
        `params.train.transpose_input` and
        `params.retinanet_parser.num_channels`.
    """
    super().__init__(params)

    # For eval metrics.
    self._params = params

    # Architecture generators.
    self._backbone_fn = factory.backbone_generator(params)
    self._fpn_fn = factory.multilevel_features_generator(params)
    self._head_fn = factory.retinanet_head_generator(params)

    # Loss function.
    self._cls_loss_fn = losses.RetinanetClassLoss(
        params.retinanet_loss, params.architecture.num_classes)
    self._box_loss_fn = losses.RetinanetBoxLoss(params.retinanet_loss)
    self._box_loss_weight = params.retinanet_loss.box_loss_weight

    # Lazily created (and then cached) by `build_model()`.
    self._keras_model = None

    # Predict function.
    self._generate_detections_fn = postprocess_ops.MultilevelDetectionGenerator(
        params.architecture.min_level, params.architecture.max_level,
        params.postprocess)

    self._transpose_input = params.train.transpose_input
    assert not self._transpose_input, 'Transpose input is not supported.'

    # Input layer.
    # NOTE(review): `self._use_bfloat16` is not set in this class; presumably
    # it is set by `base_model.Model.__init__` -- confirm.
    self._input_layer = tf_keras.layers.Input(
        shape=(None, None, params.retinanet_parser.num_channels),
        name='',
        dtype=tf.bfloat16 if self._use_bfloat16 else tf.float32)

  def build_outputs(self, inputs, mode):
    """Runs backbone, feature generator and head on `inputs`.

    Args:
      inputs: the input image tensor.
      mode: one of the `mode_keys` values; the sub-networks are run with
        `is_training=True` only when it equals `mode_keys.TRAIN`.

    Returns:
      A dict with keys 'cls_outputs' and 'box_outputs', each holding the
      per-level outputs returned by the head. When bfloat16 is in use the
      per-level tensors are cast to float32.
    """
    # If the input image is transposed (from NHWC to HWCN), we need to revert it
    # back to the original shape before it's used in the computation.
    if self._transpose_input:
      inputs = tf.transpose(inputs, [3, 0, 1, 2])

    is_training = mode == mode_keys.TRAIN
    backbone_features = self._backbone_fn(inputs, is_training=is_training)
    fpn_features = self._fpn_fn(backbone_features, is_training=is_training)
    cls_outputs, box_outputs = self._head_fn(
        fpn_features, is_training=is_training)

    if self._use_bfloat16:
      # Only values of existing keys are replaced, so iterating the dict while
      # assigning is safe.
      for level in cls_outputs:
        cls_outputs[level] = tf.cast(cls_outputs[level], tf.float32)
        box_outputs[level] = tf.cast(box_outputs[level], tf.float32)

    return {
        'cls_outputs': cls_outputs,
        'box_outputs': box_outputs,
    }

  def build_loss_fn(self):
    """Returns a function computing the training losses.

    Returns:
      A callable `fn(labels, outputs)` returning a dict with keys
      'total_loss', 'cls_loss', 'box_loss', 'model_loss' and
      'l2_regularization_loss'.

    Raises:
      ValueError: if called before `build_model()` has created the Keras model.
    """
    if self._keras_model is None:
      raise ValueError('build_loss_fn() must be called after build_model().')

    # The set of variables subject to weight decay is fixed once, here, rather
    # than on every call of the returned loss function.
    filter_fn = self.make_filter_trainable_variables_fn()
    trainable_variables = filter_fn(self._keras_model.trainable_variables)

    def _total_loss_fn(labels, outputs):
      cls_loss = self._cls_loss_fn(outputs['cls_outputs'],
                                   labels['cls_targets'],
                                   labels['num_positives'])
      box_loss = self._box_loss_fn(outputs['box_outputs'],
                                   labels['box_targets'],
                                   labels['num_positives'])
      model_loss = cls_loss + self._box_loss_weight * box_loss
      l2_regularization_loss = self.weight_decay_loss(trainable_variables)
      total_loss = model_loss + l2_regularization_loss
      return {
          'total_loss': total_loss,
          'cls_loss': cls_loss,
          'box_loss': box_loss,
          'model_loss': model_loss,
          'l2_regularization_loss': l2_regularization_loss,
      }

    return _total_loss_fn

  def build_model(self, params, mode=None):
    """Builds the Keras model on first call and returns the cached instance.

    Args:
      params: unused by this method; kept for interface compatibility.
      mode: forwarded to `self.model_outputs` when the model is first built.

    Returns:
      The `tf_keras.models.Model` named 'retinanet', with its `optimizer`
      attribute set from `self.build_optimizer()`.
    """
    if self._keras_model is None:
      outputs = self.model_outputs(self._input_layer, mode)

      model = tf_keras.models.Model(
          inputs=self._input_layer, outputs=outputs, name='retinanet')
      assert model is not None, 'Fail to build tf_keras.Model.'
      model.optimizer = self.build_optimizer()
      self._keras_model = model

    return self._keras_model

  def post_processing(self, labels, outputs):
    """Converts raw head outputs into detections and flattens groundtruths.

    Args:
      labels: a dict that must contain 'image_info', 'groundtruths' and
        'anchor_boxes'. It is updated in place: the 'source_id', 'boxes',
        'classes', 'areas' and 'is_crowds' entries of `labels['groundtruths']`
        are copied to the top level.
      outputs: a dict that must contain 'cls_outputs' and 'box_outputs'.

    Returns:
      A `(labels, outputs)` tuple where `labels` is the updated input dict and
      `outputs` is a new dict with keys 'source_id', 'image_info',
      'num_detections', 'detection_boxes', 'detection_classes' and
      'detection_scores'.

    Raises:
      ValueError: if a required field is missing from `outputs` or `labels`.
    """
    # TODO(yeqing): Moves the output related part into build_outputs.
    _check_required_fields(
        outputs, ['cls_outputs', 'box_outputs'], 'outputs')
    # 'anchor_boxes' is read below, so it is validated here with the other
    # label fields instead of surfacing later as a bare KeyError.
    _check_required_fields(
        labels, ['image_info', 'groundtruths', 'anchor_boxes'], 'labels')

    boxes, scores, classes, valid_detections = self._generate_detections_fn(
        outputs['box_outputs'], outputs['cls_outputs'], labels['anchor_boxes'],
        labels['image_info'][:, 1:2, :])

    # Discards the old output tensors to save memory. The `cls_outputs` and
    # `box_outputs` are pretty big and could potentially lead to memory issue.
    groundtruths = labels['groundtruths']
    outputs = {
        'source_id': groundtruths['source_id'],
        'image_info': labels['image_info'],
        'num_detections': valid_detections,
        'detection_boxes': boxes,
        'detection_classes': classes,
        'detection_scores': scores,
    }

    # 'groundtruths' is guaranteed to be present by the validation above.
    labels['source_id'] = groundtruths['source_id']
    labels['boxes'] = groundtruths['boxes']
    labels['classes'] = groundtruths['classes']
    labels['areas'] = groundtruths['areas']
    labels['is_crowds'] = groundtruths['is_crowds']

    return labels, outputs

  def eval_metrics(self):
    """Returns the evaluator built from `params.eval`."""
    return eval_factory.evaluator_generator(self._params.eval)